~singpolyma/haskell-gnutls

06a662e63ab0345b044655410bb3c58f87cb0491 — Stephen Paul Weber 3 years ago ceac331
Switch TLS to TLST to allow any Unexceptional base monad
1 files changed, 30 insertions(+), 34 deletions(-)

M lib/Network/Protocol/TLS/GNU.hs
M lib/Network/Protocol/TLS/GNU.hs => lib/Network/Protocol/TLS/GNU.hs +30 -34
@@ 17,10 17,9 @@

module Network.Protocol.TLS.GNU
	( TLS
	, TLST
	, Session
	, Error (..)
	, throwE
	, fromExceptT
	
	, runTLS
	, runTLS'


@@ 46,7 45,6 @@ import           Control.Monad (when, foldM, foldM_)
import           Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.Except as E
import qualified Control.Monad.Trans.Reader as R
import           Control.Monad.IO.Class (liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B


@@ 56,7 54,7 @@ import qualified Foreign.C as F
import           Foreign.Concurrent as FC
import qualified System.IO as IO
import           System.IO.Unsafe (unsafePerformIO)
import           UnexceptionalIO.Trans (UIO, Unexceptional)
import           UnexceptionalIO.Trans (Unexceptional)
import qualified UnexceptionalIO.Trans as UIO

import qualified Network.Protocol.TLS.GNU.Foreign as F


@@ 68,10 66,10 @@ globalInitMVar :: M.MVar ()
{-# NOINLINE globalInitMVar #-}
globalInitMVar = unsafePerformIO $ M.newMVar ()

globalInit :: E.ExceptT Error IO ()
globalInit :: (Unexceptional m) => E.ExceptT Error m ()
globalInit = do
	let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
	F.ReturnCode rc <- liftIO init_
	F.ReturnCode rc <- UIO.unsafeFromIO init_
	when (rc < 0) $ E.throwE $ mapError rc

globalDeinit :: IO ()


@@ 90,33 88,31 @@ data Session = Session
	, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
	}

type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a
type TLS a = TLST IO a
type TLST m a = E.ExceptT Error (R.ReaderT Session m) a

throwE :: Error -> TLS a
throwE = fromExceptT . E.throwE

fromExceptT :: E.ExceptT Error UIO a -> TLS a
fromExceptT = E.mapExceptT lift

runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a)
runTLS :: (Unexceptional m) => Session -> TLST m a -> m (Either Error a)
runTLS s = E.runExceptT . runTLS' s

runTLS' :: (Unexceptional m) => Session -> TLS a -> E.ExceptT Error m a
runTLS' s = E.mapExceptT (UIO.lift . flip R.runReaderT s)
runTLS' :: Session -> TLST m a -> E.ExceptT Error m a
runTLS' s = E.mapExceptT (flip R.runReaderT s)

runClient :: Transport -> TLS a -> IO (Either Error a)
runClient :: (Unexceptional m) => Transport -> TLST m a -> m (Either Error a)
runClient transport tls = do
	eitherSession <- newSession transport (F.ConnectionEnd 2)
	case eitherSession of
		Left err -> return (Left err)
		Right session -> runTLS session tls

newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
newSession :: (Unexceptional m) =>
	   Transport
	-> F.ConnectionEnd
	-> m (Either Error Session)
newSession transport end = UIO.unsafeFromIO . F.alloca $ \sPtr -> E.runExceptT $ do
	globalInit
	F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
	F.ReturnCode rc <- UIO.unsafeFromIO $ F.gnutls_init sPtr end
	when (rc < 0) $ E.throwE $ mapError rc
	liftIO $ do
	UIO.unsafeFromIO $ do
		ptr <- F.peek sPtr
		let session = F.Session ptr
		push <- F.wrapTransportFunc (pushImpl transport)


@@ 132,22 128,22 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do
			F.freeHaskellFunPtr pull
		return (Session fp creds)

getSession :: TLS Session
getSession :: (Monad m) => TLST m Session
getSession = lift R.ask

handshake :: TLS ()
handshake :: (Unexceptional m) => TLST m ()
handshake = unsafeWithSession F.gnutls_handshake >>= checkRC

rehandshake :: TLS ()
rehandshake :: (Unexceptional m) => TLST m ()
rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC

putBytes :: BL.ByteString -> TLS ()
putBytes :: (Unexceptional m) => BL.ByteString -> TLST m ()
putBytes = putChunks . BL.toChunks where
	putChunks chunks = do
		maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks
		case maybeErr of
			Nothing -> return ()
			Just err -> throwE $ mapError $ fromIntegral err
			Just err -> E.throwE $ mapError $ fromIntegral err
	
	putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
		loop ptr len = do


@@ 161,7 157,7 @@ putBytes = putChunks . BL.toChunks where
	
	putChunk _ err _ = return err

getBytes :: Integer -> TLS BL.ByteString
getBytes :: (Unexceptional m) => Integer -> TLST m BL.ByteString
getBytes count = do
	(mbytes, len) <- unsafeWithSession $ \s ->
		F.allocaBytes (fromInteger count) $ \ptr -> do


@@ 175,9 171,9 @@ getBytes count = do
	
	case mbytes of
		Just bytes -> return bytes
		Nothing   -> throwE $ mapError $ fromIntegral len
		Nothing   -> E.throwE $ mapError $ fromIntegral len

checkPending :: TLS Integer
checkPending :: (Unexceptional m) => TLST m Integer
checkPending = unsafeWithSession $ \s -> do
	pending <- F.gnutls_record_check_pending s
	return $ toInteger pending


@@ 209,7 205,7 @@ handleTransport h = Transport (BL.hPut h) (BL.hGet h . fromInteger)

data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials)

setCredentials :: Credentials -> TLS ()
setCredentials :: (Unexceptional m) => Credentials -> TLST m ()
setCredentials (Credentials ctype fp) = do
	rc <- unsafeWithSession $ \s ->
		F.withForeignPtr fp $ \ptr -> do


@@ 220,7 216,7 @@ setCredentials (Credentials ctype fp) = do
		then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ())))
		else checkRC rc

certificateCredentials :: TLS Credentials
certificateCredentials :: (Unexceptional m) => TLST m Credentials
certificateCredentials = do
	(rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do
		rc <- F.gnutls_certificate_allocate_credentials ptr


@@ 233,13 229,13 @@ certificateCredentials = do
	return $ Credentials (F.CredentialsType 1) fp

-- | This must only be called with IO actions that do not throw NonPseudoException
unsafeWithSession :: (F.Session -> IO a) -> TLS a
unsafeWithSession :: (Unexceptional m) => (F.Session -> IO a) -> TLST m a
unsafeWithSession io = do
	s <- getSession
	UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session

checkRC :: F.ReturnCode -> TLS ()
checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x
checkRC :: (Monad m) => F.ReturnCode -> TLST m ()
checkRC (F.ReturnCode x) = when (x < 0) $ E.throwE $ mapError x

mapError :: F.CInt -> Error
mapError = Error . toInteger