@@ 19,6 19,9 @@ module Network.Protocol.TLS.GNU
( TLS
, Session
, Error (..)
+ , throwE
+ , catchE
+ , fromExceptT
, runTLS
, runClient
@@ 42,7 45,7 @@ import Control.Applicative (Applicative, pure, (<*>))
import qualified Control.Concurrent.MVar as M
import Control.Monad (ap, when, foldM, foldM_)
import Control.Monad.Trans.Class (lift)
-import Control.Monad.Trans.Except
+import qualified Control.Monad.Trans.Except as E
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Data.ByteString as B
@@ 66,11 69,11 @@ globalInitMVar :: M.MVar ()
{-# NOINLINE globalInitMVar #-}
globalInitMVar = unsafePerformIO $ M.newMVar ()
-globalInit :: ExceptT Error IO ()
+globalInit :: E.ExceptT Error IO ()
globalInit = do
let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
F.ReturnCode rc <- liftIO init_
- when (rc < 0) $ throwE $ mapError rc
+ when (rc < 0) $ E.throwE $ mapError rc
globalDeinit :: IO ()
globalDeinit = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_deinit
@@ 88,7 91,7 @@ data Session = Session
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
}
-newtype TLS a = TLS { unTLS :: ExceptT Error (R.ReaderT Session UIO) a }
+newtype TLS a = TLS { unTLS :: E.ExceptT Error (R.ReaderT Session UIO) a }
instance Functor TLS where
fmap f = TLS . fmap f . unTLS
@@ 103,10 106,19 @@ instance Monad TLS where
-- | This is a transitional instance and may be deprecated in the future
instance MonadIO TLS where
- liftIO = TLS . withExceptT IOError . UIO.fromIO' (userError . show)
+ liftIO = TLS . E.withExceptT IOError . UIO.fromIO' (userError . show)
+
+throwE :: Error -> TLS a
+throwE = fromExceptT . E.throwE
+
+catchE :: TLS a -> (Error -> TLS a) -> TLS a
+catchE inner handler = TLS $ unTLS inner `E.catchE` (unTLS . handler)
+
+fromExceptT :: E.ExceptT Error UIO a -> TLS a
+fromExceptT = TLS . E.mapExceptT lift
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
-runTLS s tls = UIO.lift $ R.runReaderT (runExceptT (unTLS tls)) s
+runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT (unTLS tls)) s
runClient :: Transport -> TLS a -> IO (Either Error a)
runClient transport tls = do
@@ 116,10 128,10 @@ runClient transport tls = do
Right session -> runTLS session tls
newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
-newSession transport end = F.alloca $ \sPtr -> runExceptT $ do
+newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
globalInit
F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
- when (rc < 0) $ throwE $ mapError rc
+ when (rc < 0) $ E.throwE $ mapError rc
liftIO $ do
ptr <- F.peek sPtr
let session = F.Session ptr
@@ 151,7 163,7 @@ putBytes = putChunks . BL.toChunks where
maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
case maybeErr of
Nothing -> return ()
- Just err -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral err
+ Just err -> throwE $ mapError $ fromIntegral err
putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
loop ptr len = do
@@ 179,7 191,7 @@ getBytes count = do
case mbytes of
Just bytes -> return bytes
- Nothing -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral len
+ Nothing -> throwE $ mapError $ fromIntegral len
checkPending :: TLS Integer
checkPending = withSession $ \s -> do
@@ 242,7 254,7 @@ withSession io = do
liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
checkRC :: F.ReturnCode -> TLS ()
-checkRC (F.ReturnCode x) = when (x < 0) $ TLS $ mapExceptT lift $ throwE $ mapError x
+checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x
mapError :: F.CInt -> Error
mapError = Error . toInteger