~singpolyma/network-protocol-xmpp

d0f194da32afc3a5fbf63d543721ff59786bdafb — John Millikin 14 years ago 2969f4f
Add error handling hooks to 'Handle' computation signatures, to simplify the migration to a better GNU TLS binding.
M Network/Protocol/XMPP/Client.hs => Network/Protocol/XMPP/Client.hs +8 -6
@@ 35,7 35,8 @@ import qualified Network.Protocol.XMPP.Client.Features as F
import qualified Network.Protocol.XMPP.Handle as H
import qualified Network.Protocol.XMPP.JID as J
import qualified Network.Protocol.XMPP.Monad as M
import Network.Protocol.XMPP.XML (element, qname, readEventsUntil)
import Network.Protocol.XMPP.XML (element, qname)
import Network.Protocol.XMPP.ErrorT
import Network.Protocol.XMPP.Stanza

runClient :: C.Server


@@ 61,9 62,8 @@ runClient server jid username password xmpp = do

newStream :: J.JID -> M.XMPP [F.Feature]
newStream jid = do
	M.Context h _ sax <- M.getContext
	liftIO $ H.hPutBytes h $ C.xmlHeader "jabber:client" jid
	liftIO $ readEventsUntil C.startOfStream h sax
	M.putBytes $ C.xmlHeader "jabber:client" jid
	M.readEvents C.startOfStream
	F.parseFeatures `fmap` M.getTree

tryTLS :: [F.Feature] -> M.XMPP a -> M.XMPP a


@@ 73,8 73,10 @@ tryTLS features m
		M.putTree xmlStartTLS
		M.getTree
		h <- M.getHandle
		tls <- liftIO $ H.startTLS h
		M.restartXMPP (Just tls) m
		eitherTLS <- liftIO $ runErrorT $ H.startTLS h
		case eitherTLS of
			Left err -> throwError $ M.TransportError err
			Right tls -> M.restartXMPP (Just tls) m

authenticationMechanisms :: [F.Feature] -> [ByteString]
authenticationMechanisms = step where

M Network/Protocol/XMPP/Component.hs => Network/Protocol/XMPP/Component.hs +3 -5
@@ 20,7 20,6 @@ module Network.Protocol.XMPP.Component
	) where
import Control.Monad (when)
import Control.Monad.Error (throwError)
import Control.Monad.Trans (liftIO)
import Data.Bits (shiftR, (.&.))
import Data.Char (intToDigit)
import qualified Data.ByteString as B


@@ 38,7 37,7 @@ import qualified Text.XML.LibXML.SAX as SAX
import qualified Network.Protocol.XMPP.Connections as C
import qualified Network.Protocol.XMPP.Handle as H
import qualified Network.Protocol.XMPP.Monad as M
import Network.Protocol.XMPP.XML (element, qname, readEventsUntil)
import Network.Protocol.XMPP.XML (element, qname)
import Network.Protocol.XMPP.JID (JID)

runComponent :: C.Server


@@ 57,9 56,8 @@ runComponent server password xmpp = do

beginStream :: JID -> M.XMPP T.Text
beginStream jid = do
	M.Context h _ sax <- M.getContext
	liftIO $ H.hPutBytes h $ C.xmlHeader "jabber:component:accept" jid
	events <- liftIO $ readEventsUntil C.startOfStream h sax
	M.putBytes $ C.xmlHeader "jabber:component:accept" jid
	events <- M.readEvents C.startOfStream
	case parseStreamID $ last events of
		Nothing -> throwError M.NoComponentStreamID
		Just x -> return x

M Network/Protocol/XMPP/Connections.hs => Network/Protocol/XMPP/Connections.hs +1 -1
@@ 51,7 51,7 @@ xmlHeader ns jid = encodeUtf8 header where
		, " xmlns:stream=\"http://etherx.jabber.org/streams\">"
		]

startOfStream :: Int -> SAX.Event -> Bool
startOfStream :: Integer -> SAX.Event -> Bool
startOfStream depth event = case (depth, event) of
	(1, (SAX.BeginElement elemName _)) ->
		qnameStream == convertQName elemName

M Network/Protocol/XMPP/Handle.hs => Network/Protocol/XMPP/Handle.hs +15 -10
@@ 13,6 13,7 @@
-- You should have received a copy of the GNU General Public License
-- along with this program.  If not, see <http://www.gnu.org/licenses/>.

{-# LANGUAGE OverloadedStrings #-}
module Network.Protocol.XMPP.Handle
	( Handle (..)
	, startTLS


@@ 21,21 22,25 @@ module Network.Protocol.XMPP.Handle
	) where

import Control.Monad (when)
import qualified System.IO as IO
import qualified Control.Monad.Error as E
import Control.Monad.Trans (liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import qualified Data.Text as T
import qualified System.IO as IO
import qualified Network.GnuTLS as GnuTLS
import Network.GnuTLS (AttrOp (..))
import Foreign (allocaBytes, plusPtr)
import Foreign.C (peekCAStringLen)
import Network.Protocol.XMPP.ErrorT

data Handle =
	  PlainHandle IO.Handle
	| SecureHandle IO.Handle (GnuTLS.Session GnuTLS.Client)

startTLS :: Handle -> IO Handle
startTLS h@(SecureHandle _ _) = return h
startTLS (PlainHandle h) = do
startTLS :: Handle -> ErrorT T.Text IO Handle
startTLS (SecureHandle _ _) = E.throwError "Can't start TLS on a secure handle"
startTLS (PlainHandle h) = liftIO $ do
	session <- GnuTLS.tlsClient
		[ GnuTLS.handle := h
		, GnuTLS.priorities := [GnuTLS.CrtX509]


@@ 44,9 49,9 @@ startTLS (PlainHandle h) = do
	GnuTLS.handshake session
	return $ SecureHandle h session

hPutBytes :: Handle -> B.ByteString -> IO ()
hPutBytes (PlainHandle h)          bytes = B.hPut h bytes
hPutBytes (SecureHandle _ session) bytes = useLoop where
hPutBytes :: Handle -> B.ByteString -> ErrorT T.Text IO ()
hPutBytes (PlainHandle h)           bytes = liftIO $ B.hPut h bytes
hPutBytes (SecureHandle _ session) bytes = liftIO useLoop where
	useLoop = B.unsafeUseAsCStringLen bytes $ uncurry loop
	loop ptr len = do
		r <- GnuTLS.tlsSend session ptr len


@@ 54,9 59,9 @@ hPutBytes (SecureHandle _ session) bytes = useLoop where
			x | x > 0     -> loop (plusPtr ptr r) x
			  | otherwise -> return ()

hGetChar :: Handle -> IO Char
hGetChar (PlainHandle h) = IO.hGetChar h
hGetChar (SecureHandle h session) = allocaBytes 1 $ \ptr -> do
hGetChar :: Handle -> ErrorT T.Text IO Char
hGetChar (PlainHandle h) = liftIO $ IO.hGetChar h
hGetChar (SecureHandle h session) = liftIO $ allocaBytes 1 $ \ptr -> do
	pending <- GnuTLS.tlsCheckPending session
	when (pending == 0) $ do
		IO.hWaitForInput h (-1)

M Network/Protocol/XMPP/Monad.hs => Network/Protocol/XMPP/Monad.hs +49 -11
@@ 25,28 25,39 @@ module Network.Protocol.XMPP.Monad
	, getHandle
	, getContext
	
	, putTree
	, readEvents
	, getChar
	, getTree
	, getStanza
	
	, putBytes
	, putTree
	, putStanza
	, getStanza
	) where
import Prelude hiding (getChar)
import Control.Monad.Trans (MonadIO, liftIO)
import qualified Control.Monad.Error as E
import qualified Control.Monad.Reader as R
import qualified Data.ByteString.Char8 as B
import Data.Text (Text)
import Text.XML.HXT.DOM.Interface (XmlTree)

import Text.XML.HXT.Arrow ((>>>))
import qualified Text.XML.HXT.Arrow as A
import qualified Text.XML.HXT.DOM.Interface as DOM
import qualified Text.XML.HXT.DOM.XmlNode as XN
import qualified Text.XML.LibXML.SAX as SAX

import Network.Protocol.XMPP.ErrorT
import qualified Network.Protocol.XMPP.Handle as H
import qualified Network.Protocol.XMPP.Stanza as S
import qualified Network.Protocol.XMPP.XML as X

data Error
	= InvalidStanza XmlTree
	= InvalidStanza DOM.XmlTree
	| InvalidBindResult S.ReceivedStanza
	| AuthenticationFailure
	| AuthenticationError Text
	| TransportError Text
	| NoComponentStreamID
	| ComponentHandshakeFailed
	deriving (Show)


@@ 93,19 104,46 @@ getHandle = do
	Context h _ _ <- getContext
	return h

putTree :: XmlTree -> XMPP ()
putTree t = do
liftTLS :: ErrorT Text IO a -> XMPP a
liftTLS io = do
	res <- liftIO $ runErrorT io
	case res of
		Left err -> E.throwError $ TransportError err
		Right x -> return x


putBytes :: B.ByteString -> XMPP ()
putBytes bytes = do
	h <- getHandle
	liftIO $ X.putTree h t
	liftTLS $ H.hPutBytes h bytes

getTree :: XMPP XmlTree
getTree = do
	Context h _ sax <- getContext
	liftIO $ X.getTree h sax
getChar :: XMPP Char
getChar = do
	h <- getHandle
	liftTLS $ H.hGetChar h

putTree :: DOM.XmlTree -> XMPP ()
putTree t = do
	let root = XN.mkRoot [] [t]
	[text] <- liftIO $ A.runX (A.constA root >>> A.writeDocumentToString [
		(A.a_no_xml_pi, "1")
		])
	h <- getHandle
	liftTLS $ H.hPutBytes h $ B.pack text

putStanza :: S.Stanza a => a -> XMPP ()
putStanza = putTree . S.stanzaToTree

readEvents :: (Integer -> SAX.Event -> Bool) -> XMPP [SAX.Event]
readEvents done = do
	Context h _ p <- getContext
	X.readEvents done (liftTLS $ H.hGetChar h) p

getTree :: XMPP DOM.XmlTree
getTree = X.eventsToTree `fmap` readEvents endOfTree where
	endOfTree 0 (SAX.EndElement _) = True
	endOfTree _ _ = False

getStanza :: XMPP S.ReceivedStanza
getStanza = do
	tree <- getTree

M Network/Protocol/XMPP/XML.hs => Network/Protocol/XMPP/XML.hs +27 -42
@@ 14,16 14,16 @@
-- along with this program.  If not, see <http://www.gnu.org/licenses/>.

module Network.Protocol.XMPP.XML
	( getTree
	, putTree
	, readEventsUntil
	( readEvents
	, eventsToTree
	, convertQName
	, element
	, attr
	, qname
	) where
import qualified Network.Protocol.XMPP.Handle as H
import Control.Monad.Trans (MonadIO, liftIO)
import qualified Data.ByteString.Char8 as C8
import qualified Network.Protocol.XMPP.Handle as H

-- XML Parsing
import Text.XML.HXT.Arrow ((>>>))


@@ 32,44 32,29 @@ import qualified Text.XML.HXT.DOM.Interface as DOM
import qualified Text.XML.HXT.DOM.XmlNode as XN
import qualified Text.XML.LibXML.SAX as SAX

getTree :: H.Handle -> SAX.Parser -> IO DOM.XmlTree
getTree h p = eventsToTree `fmap` readEventsUntil finished h p where
	finished 0 (SAX.EndElement _) = True
	finished _ _ = False

putTree :: H.Handle -> DOM.XmlTree -> IO ()
putTree h t = do
	let root = XN.mkRoot [] [t]
	[text] <- A.runX (A.constA root >>> A.writeDocumentToString [
		(A.a_no_xml_pi, "1")
		])
	H.hPutBytes h $ C8.pack text

-------------------------------------------------------------------------------

readEventsUntil :: (Int -> SAX.Event -> Bool) -> H.Handle -> SAX.Parser -> IO [SAX.Event]
readEventsUntil done h parser = readEventsUntil' done 0 [] $ do
	char <- H.hGetChar h
	SAX.parse parser [char] False

readEventsUntil' :: (Int -> SAX.Event -> Bool) -> Int -> [SAX.Event] -> IO [SAX.Event] -> IO [SAX.Event]
readEventsUntil' done depth accum getEvents = do
	events <- getEvents
	let (done', depth', accum') = readEventsStep done events depth accum
	if done'
		then return accum'
		else readEventsUntil' done depth' accum' getEvents

readEventsStep :: (Int -> SAX.Event -> Bool) -> [SAX.Event] -> Int -> [SAX.Event] -> (Bool, Int, [SAX.Event])
readEventsStep _ [] depth accum = (False, depth, accum)
readEventsStep done (e:es) depth accum = let
	depth' = depth + case e of
		(SAX.BeginElement _ _) -> 1
		(SAX.EndElement _) -> (- 1)
		_ -> 0
	accum' = accum ++ [e]
	in if done depth' e then (True, depth', accum')
	else readEventsStep done es depth' accum'
readEvents :: MonadIO m => (Integer -> SAX.Event -> Bool) -> m Char -> SAX.Parser -> m [SAX.Event]
readEvents done getChar parser = readEvents' 0 [] where
	nextEvents = do
		char <- getChar
		liftIO $ SAX.parse parser [char] False
	
	readEvents' depth acc = do
		events <- nextEvents
		let (done', depth', acc') = step events depth acc
		if done'
			then return acc'
			else readEvents' depth' acc'
	
	step []     depth acc = (False, depth, acc)
	step (e:es) depth acc = let
		depth' = depth + case e of
			(SAX.BeginElement _ _) -> 1
			(SAX.EndElement _) -> (- 1)
			_ -> 0
		acc' = e : acc
		in if done depth' e
			then (True, depth', reverse acc')
			else step es depth' acc'

-------------------------------------------------------------------------------
-- For converting incremental XML event lists to HXT trees