@@ 20,7 20,6 @@ module Network.Protocol.TLS.GNU
, Session
, Error (..)
, throwE
- , catchE
, fromExceptT
, runTLS
@@ 41,13 40,12 @@ module Network.Protocol.TLS.GNU
, certificateCredentials
) where
-import Control.Applicative (Applicative, pure, (<*>))
import qualified Control.Concurrent.MVar as M
-import Control.Monad (ap, when, foldM, foldM_)
+import Control.Monad (when, foldM, foldM_)
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.Except as E
import qualified Control.Monad.Trans.Reader as R
-import Control.Monad.IO.Class (MonadIO, liftIO)
+import Control.Monad.IO.Class (liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B
@@ 62,7 60,7 @@ import qualified UnexceptionalIO.Trans as UIO
import qualified Network.Protocol.TLS.GNU.Foreign as F
-data Error = Error Integer | IOError IOError
+data Error = Error Integer
deriving (Show)
globalInitMVar :: M.MVar ()
@@ 91,34 89,16 @@ data Session = Session
, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
}
-newtype TLS a = TLS { unTLS :: E.ExceptT Error (R.ReaderT Session UIO) a }
-
-instance Functor TLS where
- fmap f = TLS . fmap f . unTLS
-
-instance Applicative TLS where
- pure = TLS . return
- (<*>) = ap
-
-instance Monad TLS where
- return = TLS . return
- m >>= f = TLS $ unTLS m >>= unTLS . f
-
--- | This is a transitional instance and may be deprecated in the future
-instance MonadIO TLS where
- liftIO = TLS . E.withExceptT IOError . UIO.fromIO' (userError . show)
+type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a
throwE :: Error -> TLS a
throwE = fromExceptT . E.throwE
-catchE :: TLS a -> (Error -> TLS a) -> TLS a
-catchE inner handler = TLS $ unTLS inner `E.catchE` (unTLS . handler)
-
fromExceptT :: E.ExceptT Error UIO a -> TLS a
-fromExceptT = TLS . E.mapExceptT lift
+fromExceptT = E.mapExceptT lift
runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
-runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT (unTLS tls)) s
+runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT tls) s
runClient :: Transport -> TLS a -> IO (Either Error a)
runClient transport tls = do
@@ 149,18 129,18 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
return (Session fp creds)
getSession :: TLS Session
-getSession = TLS $ lift R.ask
+getSession = lift R.ask
handshake :: TLS ()
-handshake = withSession F.gnutls_handshake >>= checkRC
+handshake = unsafeWithSession F.gnutls_handshake >>= checkRC
rehandshake :: TLS ()
-rehandshake = withSession F.gnutls_rehandshake >>= checkRC
+rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC
putBytes :: BL.ByteString -> TLS ()
putBytes = putChunks . BL.toChunks where
putChunks chunks = do
- maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
+ maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
case maybeErr of
Nothing -> return ()
Just err -> throwE $ mapError $ fromIntegral err
@@ 179,7 159,7 @@ putBytes = putChunks . BL.toChunks where
getBytes :: Integer -> TLS BL.ByteString
getBytes count = do
- (mbytes, len) <- withSession $ \s ->
+ (mbytes, len) <- unsafeWithSession $ \s ->
F.allocaBytes (fromInteger count) $ \ptr -> do
len <- F.gnutls_record_recv s ptr (fromInteger count)
bytes <- if len >= 0
@@ 194,7 174,7 @@ getBytes count = do
Nothing -> throwE $ mapError $ fromIntegral len
checkPending :: TLS Integer
-checkPending = withSession $ \s -> do
+checkPending = unsafeWithSession $ \s -> do
pending <- F.gnutls_record_check_pending s
return $ toInteger pending
@@ 227,31 207,32 @@ data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)
setCredentials :: Credentials -> TLS ()
setCredentials (Credentials ctype fp) = do
- rc <- withSession $ \s ->
+ rc <- unsafeWithSession $ \s ->
F.withForeignPtr fp $ \ptr -> do
F.gnutls_credentials_set s ctype ptr
s <- getSession
if F.unRC rc == 0
- then liftIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
+ then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
else checkRC rc
certificateCredentials :: TLS Credentials
certificateCredentials = do
- (rc, ptr) <- liftIO $ F.alloca $ \ptr -> do
+ (rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do
rc <- F.gnutls_certificate_allocate_credentials ptr
ptr' <- if F.unRC rc < 0
then return F.nullPtr
else F.peek ptr
return (rc, ptr')
checkRC rc
- fp <- liftIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
+ fp <- UIO.unsafeFromIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
return $ Credentials (F.CredentialsType 1) fp
-withSession :: (F.Session -> IO a) -> TLS a
-withSession io = do
+-- | This must only be called with IO actions that do not throw NonPseudoException
+unsafeWithSession :: (F.Session -> IO a) -> TLS a
+unsafeWithSession io = do
s <- getSession
- liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
+ UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
checkRC :: F.ReturnCode -> TLS ()
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x