@@ 17,10 17,9 @@
module Network.Protocol.TLS.GNU
( TLS
+ , TLST
, Session
, Error (..)
- , throwE
- , fromExceptT
, runTLS
, runTLS'
@@ 46,7 45,6 @@ import Control.Monad (when, foldM, foldM_)
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.Except as E
import qualified Control.Monad.Trans.Reader as R
-import Control.Monad.IO.Class (liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B
@@ 56,7 54,7 @@ import qualified Foreign.C as F
import Foreign.Concurrent as FC
import qualified System.IO as IO
import System.IO.Unsafe (unsafePerformIO)
-import UnexceptionalIO.Trans (UIO, Unexceptional)
+import UnexceptionalIO.Trans (Unexceptional)
import qualified UnexceptionalIO.Trans as UIO
import qualified Network.Protocol.TLS.GNU.Foreign as F
@@ 68,10 66,10 @@ globalInitMVar :: M.MVar ()
{-# NOINLINE globalInitMVar #-}
globalInitMVar = unsafePerformIO $ M.newMVar ()
-globalInit :: E.ExceptT Error IO ()
+globalInit :: (Unexceptional m) => E.ExceptT Error m ()
globalInit = do
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
- F.ReturnCode rc <- liftIO init_
+ F.ReturnCode rc <- UIO.unsafeFromIO init_
when (rc < 0) $ E.throwE $ mapError rc
globalDeinit :: IO ()
@@ 90,33 88,31 @@ data Session = Session
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
}
-type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a
+type TLS a = TLST IO a
+type TLST m a = E.ExceptT Error (R.ReaderT Session m) a
-throwE :: Error -> TLS a
-throwE = fromExceptT . E.throwE
-
-fromExceptT :: E.ExceptT Error UIO a -> TLS a
-fromExceptT = E.mapExceptT lift
-
-runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
+runTLS :: (Unexceptional m) => Session -> TLST m a -> m (Either Error a)
runTLS s = E.runExceptT . runTLS' s
-runTLS' :: (Unexceptional m) => Session -> TLS a -> E.ExceptT Error m a
-runTLS' s = E.mapExceptT (UIO.lift . flip R.runReaderT s)
+runTLS' :: Session -> TLST m a -> E.ExceptT Error m a
+runTLS' s = E.mapExceptT (flip R.runReaderT s)
-runClient :: Transport -> TLS a -> IO (Either Error a)
+runClient :: (Unexceptional m) => Transport -> TLST m a -> m (Either Error a)
runClient transport tls = do
eitherSession <- newSession transport (F.ConnectionEnd 2)
case eitherSession of
Left err -> return (Left err)
Right session -> runTLS session tls
-newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
-newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
+newSession :: (Unexceptional m) =>
+ Transport
+ -> F.ConnectionEnd
+ -> m (Either Error Session)
+newSession transport end = UIO.unsafeFromIO . F.alloca $ \sPtr -> E.runExceptT $ do
globalInit
- F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
+ F.ReturnCode rc <- UIO.unsafeFromIO $ F.gnutls_init sPtr end
when (rc < 0) $ E.throwE $ mapError rc
- liftIO $ do
+ UIO.unsafeFromIO $ do
ptr <- F.peek sPtr
let session = F.Session ptr
push <- F.wrapTransportFunc (pushImpl transport)
@@ 132,22 128,22 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
F.freeHaskellFunPtr pull
return (Session fp creds)
-getSession :: TLS Session
+getSession :: (Monad m) => TLST m Session
getSession = lift R.ask
-handshake :: TLS ()
+handshake :: (Unexceptional m) => TLST m ()
handshake = unsafeWithSession F.gnutls_handshake >>= checkRC
-rehandshake :: TLS ()
+rehandshake :: (Unexceptional m) => TLST m ()
rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC
-putBytes :: BL.ByteString -> TLS ()
+putBytes :: (Unexceptional m) => BL.ByteString -> TLST m ()
putBytes = putChunks . BL.toChunks where
putChunks chunks = do
maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
case maybeErr of
Nothing -> return ()
- Just err -> throwE $ mapError $ fromIntegral err
+ Just err -> E.throwE $ mapError $ fromIntegral err
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
loop ptr len = do
@@ 161,7 157,7 @@ putBytes = putChunks . BL.toChunks where
putChunk _ err _ = return err
-getBytes :: Integer -> TLS BL.ByteString
+getBytes :: (Unexceptional m) => Integer -> TLST m BL.ByteString
getBytes count = do
(mbytes, len) <- unsafeWithSession $ \s ->
F.allocaBytes (fromInteger count) $ \ptr -> do
@@ 175,9 171,9 @@ getBytes count = do
case mbytes of
Just bytes -> return bytes
- Nothing -> throwE $ mapError $ fromIntegral len
+ Nothing -> E.throwE $ mapError $ fromIntegral len
-checkPending :: TLS Integer
+checkPending :: (Unexceptional m) => TLST m Integer
checkPending = unsafeWithSession $ \s -> do
pending <- F.gnutls_record_check_pending s
return $ toInteger pending
@@ 209,7 205,7 @@ handleTransport h = Transport (BL.hPut h) (BL.hGet h . fromInteger)
data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
-setCredentials :: Credentials -> TLS ()
+setCredentials :: (Unexceptional m) => Credentials -> TLST m ()
setCredentials (Credentials ctype fp) = do
rc <- unsafeWithSession $ \s ->
F.withForeignPtr fp $ \ptr -> do
@@ 220,7 216,7 @@ setCredentials (Credentials ctype fp) = do
then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
else checkRC rc
-certificateCredentials :: TLS Credentials
+certificateCredentials :: (Unexceptional m) => TLST m Credentials
certificateCredentials = do
(rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do
rc <- F.gnutls_certificate_allocate_credentials ptr
@@ 233,13 229,13 @@ certificateCredentials = do
return $ Credentials (F.CredentialsType 1) fp
-- | This must only be called with IO actions that do not throw NonPseudoException
-unsafeWithSession :: (F.Session -> IO a) -> TLS a
+unsafeWithSession :: (Unexceptional m) => (F.Session -> IO a) -> TLST m a
unsafeWithSession io = do
s <- getSession
UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
-checkRC :: F.ReturnCode -> TLS ()
-checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x
+checkRC :: (Monad m) => F.ReturnCode -> TLST m ()
+checkRC (F.ReturnCode x) = when (x < 0) $ E.throwE $ mapError x
mapError :: F.CInt -> Error
mapError = Error . toInteger