~singpolyma/haskell-gnutls

28825761260433f89a6309283276a4a267e43894 — Stephen Paul Weber 3 years ago b32c661
Switch from monads-tf and custom transformer to ExceptT

When this code was written ExceptT didn't exist yet, but there's no reason to
vendor a custom implementation of it any longer.  We're taking very little
advantage of the monads-tf features so just remove that dependency (and the
language extension reliance that goes with it).
3 files changed, 15 insertions(+), 105 deletions(-)

M gnutls.cabal
M lib/Network/Protocol/TLS/GNU.hs
D lib/Network/Protocol/TLS/GNU/ErrorT.hs
M gnutls.cabal => gnutls.cabal +1 -3
@@ 35,8 35,7 @@ library
  build-depends:
      base >= 4.0 && < 5.0
    , bytestring >= 0.9
    , transformers >= 0.2
    , monads-tf >= 0.1 && < 0.2
    , transformers >= 0.4.0.0

  extra-libraries: gnutls
  pkgconfig-depends: gnutls


@@ 45,5 44,4 @@ library
    Network.Protocol.TLS.GNU

  other-modules:
    Network.Protocol.TLS.GNU.ErrorT
    Network.Protocol.TLS.GNU.Foreign

M lib/Network/Protocol/TLS/GNU.hs => lib/Network/Protocol/TLS/GNU.hs +14 -20
@@ 41,10 41,10 @@ module Network.Protocol.TLS.GNU
import           Control.Applicative (Applicative, pure, (<*>))
import qualified Control.Concurrent.MVar as M
import           Control.Monad (ap, when, foldM, foldM_)
import qualified Control.Monad.Error as E
import           Control.Monad.Error (ErrorType)
import qualified Control.Monad.Reader as R
import           Control.Monad.Trans (MonadIO, liftIO)
import           Control.Monad.Trans.Class (lift)
import           Control.Monad.Trans.Except
import qualified Control.Monad.Trans.Reader as R
import           Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Unsafe as B


@@ 55,7 55,6 @@ import           Foreign.Concurrent as FC
import qualified System.IO as IO
import           System.IO.Unsafe (unsafePerformIO)

import           Network.Protocol.TLS.GNU.ErrorT
import qualified Network.Protocol.TLS.GNU.Foreign as F

data Error = Error Integer


@@ 65,11 64,11 @@ globalInitMVar :: M.MVar ()
{-# NOINLINE globalInitMVar #-}
globalInitMVar = unsafePerformIO $ M.newMVar ()

globalInit :: ErrorT Error IO ()
globalInit :: ExceptT Error IO ()
globalInit = do
	let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init
	F.ReturnCode rc <- liftIO init_
	when (rc < 0) $ E.throwError $ mapError rc
	when (rc < 0) $ throwE $ mapError rc

globalDeinit :: IO ()
globalDeinit = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_deinit


@@ 87,7 86,7 @@ data Session = Session
	, sessionCredentials :: IORef [F.ForeignPtr F.Credentials]
	}

newtype TLS a = TLS { unTLS :: ErrorT Error (R.ReaderT Session IO) a }
newtype TLS a = TLS { unTLS :: ExceptT Error (R.ReaderT Session IO) a }

instance Functor TLS where
	fmap f = TLS . fmap f . unTLS


@@ 103,13 102,8 @@ instance Monad TLS where
instance MonadIO TLS where
	liftIO = TLS . liftIO

instance E.MonadError TLS where
	type ErrorType TLS = Error
	throwError = TLS . E.throwError
	catchError m h = TLS $ E.catchError (unTLS m) (unTLS . h)

runTLS :: Session -> TLS a -> IO (Either Error a)
runTLS s tls = R.runReaderT (runErrorT (unTLS tls)) s
runTLS s tls = R.runReaderT (runExceptT (unTLS tls)) s

runClient :: Transport -> TLS a -> IO (Either Error a)
runClient transport tls = do


@@ 119,10 113,10 @@ runClient transport tls = do
		Right session -> runTLS session tls

newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session)
newSession transport end = F.alloca $ \sPtr -> runErrorT $ do
newSession transport end = F.alloca $ \sPtr -> runExceptT $ do
	globalInit
	F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end
	when (rc < 0) $ E.throwError $ mapError rc
	when (rc < 0) $ throwE $ mapError rc
	liftIO $ do
		ptr <- F.peek sPtr
		let session = F.Session ptr


@@ 140,7 134,7 @@ newSession transport end = F.alloca $ \sPtr -> runErrorT $ do
		return (Session fp creds)

getSession :: TLS Session
getSession = TLS R.ask
getSession = TLS $ lift R.ask

handshake :: TLS ()
handshake = withSession F.gnutls_handshake >>= checkRC


@@ 154,7 148,7 @@ putBytes = putChunks . BL.toChunks where
		maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks
		case maybeErr of
			Nothing -> return ()
			Just err -> E.throwError $ mapError $ fromIntegral err
			Just err -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral err
	
	putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where
		loop ptr len = do


@@ 182,7 176,7 @@ getBytes count = do
	
	case mbytes of
		Just bytes -> return bytes
		Nothing   -> E.throwError $ mapError $ fromIntegral len
		Nothing   -> TLS $ mapExceptT lift $ throwE $ mapError $ fromIntegral len

checkPending :: TLS Integer
checkPending = withSession $ \s -> do


@@ 245,7 239,7 @@ withSession io = do
	liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session

checkRC :: F.ReturnCode -> TLS ()
checkRC (F.ReturnCode x) = when (x < 0) $ E.throwError $ mapError x
checkRC (F.ReturnCode x) = when (x < 0) $ TLS $ mapExceptT lift $ throwE $ mapError x

mapError :: F.CInt -> Error
mapError = Error . toInteger

D lib/Network/Protocol/TLS/GNU/ErrorT.hs => lib/Network/Protocol/TLS/GNU/ErrorT.hs +0 -82
@@ 1,82 0,0 @@
{-# LANGUAGE TypeFamilies #-}

-- Copyright (C) 2010 John Millikin <jmillikin@gmail.com>
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program.  If not, see <http://www.gnu.org/licenses/>.

module Network.Protocol.TLS.GNU.ErrorT
	( ErrorT (..)
	, mapErrorT
	) where

import           Control.Applicative (Applicative, pure, (<*>))
import           Control.Monad (ap,liftM)
import           Control.Monad.Trans (MonadIO, liftIO)
import           Control.Monad.Trans.Class (MonadTrans, lift)
import qualified Control.Monad.Error as E
import           Control.Monad.Error (ErrorType)
import qualified Control.Monad.Reader as R
import           Control.Monad.Reader (EnvType)

-- A custom version of ErrorT, without the 'Error' class restriction.

newtype ErrorT e m a = ErrorT { runErrorT :: m (Either e a) }

instance Functor m => Functor (ErrorT e m) where
	fmap f = ErrorT . fmap (fmap f) . runErrorT

instance (Functor m, Monad m) => Applicative (ErrorT e m) where
	pure a  = ErrorT $ return (Right a)
	f <*> v = ErrorT $ do
		mf <- runErrorT f
		case mf of
			Left  e -> return (Left e)
			Right k -> do
				mv <- runErrorT v
				case mv of
					Left  e -> return (Left e)
					Right x -> return (Right (k x))

instance Monad m => Monad (ErrorT e m) where
	return = ErrorT . return . Right
	(>>=) m k = ErrorT $ do
		x <- runErrorT m
		case x of
			Left l -> return $ Left l
			Right r -> runErrorT $ k r

instance Monad m => E.MonadError (ErrorT e m) where
	type ErrorType (ErrorT e m) = e
	throwError = ErrorT . return . Left
	catchError m h = ErrorT $ do
		x <- runErrorT m
		case x of
			Left l -> runErrorT $ h l
			Right r -> return $ Right r

instance MonadTrans (ErrorT e) where
	lift = ErrorT . liftM Right

instance R.MonadReader m => R.MonadReader (ErrorT e m) where
	type EnvType (ErrorT e m) = EnvType m
	ask = lift R.ask
	local = mapErrorT . R.local

instance MonadIO m => MonadIO (ErrorT e m) where
	liftIO = lift . liftIO

mapErrorT :: (m (Either e a) -> n (Either e' b))
           -> ErrorT e m a
           -> ErrorT e' n b
mapErrorT f m = ErrorT $ f (runErrorT m)