~singpolyma/haskell-gnutls

955b054ff43f3758c1c394b8f688bad742407c24 — Stephen Paul Weber 3 years ago 17b9279
Switch monad transformer stack to a type alias

Since we already allowed injecting any Session via runTLS or throwing any Error
via throwE, this does not reduce safety at all but improves ergonomics
considerably.

The only downside here is that we must say goodbye to our transitional MonadIO
instance.
1 files changed, 20 insertions(+), 39 deletions(-)

M lib/Network/Protocol/TLS/GNU.hs
M lib/Network/Protocol/TLS/GNU.hs => lib/Network/Protocol/TLS/GNU.hs +20 -39
@@ 20,7 20,6 @@ module Network.Protocol.TLS.GNU
	, Session
	, Error (..)
	, throwE
	, catchE
	, fromExceptT
	
	, runTLS


@@ 41,13 40,12 @@ module Network.Protocol.TLS.GNU
	, certificateCredentials
	) where

import           Control.Applicative (Applicative, pure, (<*>))
import qualified Control.Concurrent.MVar as M
import           Control.Monad (ap, when, foldM, foldM_)
import           Control.Monad (when, foldM, foldM_)
import           Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.Except as E
import qualified Control.Monad.Trans.Reader as R
import           Control.Monad.IO.Class (MonadIO, liftIO)
import           Control.Monad.IO.Class (liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B


@@ 62,7 60,7 @@ import qualified UnexceptionalIO.Trans as UIO

import qualified Network.Protocol.TLS.GNU.Foreign as F

data Error = Error Integer | IOError IOError
data Error = Error Integer
	deriving (Show)

globalInitMVar :: M.MVar ()


@@ 91,34 89,16 @@ data Session = Session
	, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
	}

newtype TLS a = TLS { unTLS :: E.ExceptT Error (R.ReaderT Session UIO) a }

instance Functor TLS where
	fmap f = TLS . fmap f . unTLS

instance Applicative TLS where
	pure = TLS . return
	(<*>) = ap

instance Monad TLS where
	return = TLS . return
	m >>= f = TLS $ unTLS m >>= unTLS . f

-- | This is a transitional instance and may be deprecated in the future
instance MonadIO TLS where
	liftIO = TLS . E.withExceptT IOError . UIO.fromIO' (userError . show)
type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a

throwE :: Error -> TLS a
throwE = fromExceptT . E.throwE

catchE :: TLS a -> (Error -> TLS a) -> TLS a
catchE inner handler = TLS $ unTLS inner `E.catchE` (unTLS . handler)

fromExceptT :: E.ExceptT Error UIO a -> TLS a
fromExceptT = TLS . E.mapExceptT lift
fromExceptT = E.mapExceptT lift

runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT (unTLS tls)) s
runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT tls) s

runClient :: Transport -> TLS a -> IO (Either Error a)
runClient transport tls = do


@@ 149,18 129,18 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
		return (Session fp creds)

getSession :: TLS Session
getSession = TLS $ lift R.ask
getSession = lift R.ask

handshake :: TLS ()
handshake = withSession F.gnutls_handshake >>= checkRC
handshake = unsafeWithSession F.gnutls_handshake >>= checkRC

rehandshake :: TLS ()
rehandshake = withSession F.gnutls_rehandshake >>= checkRC
rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC

putBytes :: BL.ByteString -> TLS ()
putBytes = putChunks . BL.toChunks where
	putChunks chunks = do
		maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
		maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
		case maybeErr of
			Nothing -> return ()
			Just err -> throwE $ mapError $ fromIntegral err


@@ 179,7 159,7 @@ putBytes = putChunks . BL.toChunks where

getBytes :: Integer -> TLS BL.ByteString
getBytes count = do
	(mbytes, len) <- withSession $ \s ->
	(mbytes, len) <- unsafeWithSession $ \s ->
		F.allocaBytes (fromInteger count) $ \ptr -> do
		len <- F.gnutls_record_recv s ptr (fromInteger count)
		bytes <- if len >= 0


@@ 194,7 174,7 @@ getBytes count = do
		Nothing   -> throwE $ mapError $ fromIntegral len

checkPending :: TLS Integer
checkPending = withSession $ \s -> do
checkPending = unsafeWithSession $ \s -> do
	pending <- F.gnutls_record_check_pending s
	return $ toInteger pending



@@ 227,31 207,32 @@ data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)

setCredentials :: Credentials -> TLS ()
setCredentials (Credentials ctype fp) = do
	rc <- withSession $ \s ->
	rc <- unsafeWithSession $ \s ->
		F.withForeignPtr fp $ \ptr -> do
		F.gnutls_credentials_set s ctype ptr
	
	s <- getSession
	if F.unRC rc == 0
		then liftIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
		then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
		else checkRC rc

certificateCredentials :: TLS Credentials
certificateCredentials = do
	(rc, ptr) <- liftIO $ F.alloca $ \ptr -> do
	(rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do
		rc <- F.gnutls_certificate_allocate_credentials ptr
		ptr' <- if F.unRC rc < 0
			then return F.nullPtr
			else F.peek ptr
		return (rc, ptr')
	checkRC rc
	fp <- liftIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
	fp <- UIO.unsafeFromIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr
	return $ Credentials (F.CredentialsType 1) fp

withSession :: (F.Session -> IO a) -> TLS a
withSession io = do
-- | This must only be called with IO actions that do not throw NonPseudoException
unsafeWithSession :: (F.Session -> IO a) -> TLS a
unsafeWithSession io = do
	s <- getSession
	liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session
	UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session

checkRC :: F.ReturnCode -> TLS ()
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x