From 902c60fb0b27e7bca7fe2a5d7a6123fe6bb4c8b8 Mon Sep 17 00:00:00 2001 From: John Millikin Date: Fri, 7 May 2010 19:08:28 +0000 Subject: [PATCH] Add locking to 'putStanza' and 'getStanza', to allow basic thread-safe access. --- Network/Protocol/XMPP/Monad.hs | 38 ++++++++++++++++++++++++---------- network-protocol-xmpp.cabal | 2 +- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/Network/Protocol/XMPP/Monad.hs b/Network/Protocol/XMPP/Monad.hs index 15cb5eb..0b17fe7 100644 --- a/Network/Protocol/XMPP/Monad.hs +++ b/Network/Protocol/XMPP/Monad.hs @@ -35,6 +35,7 @@ module Network.Protocol.XMPP.Monad , putStanza ) where import qualified Control.Applicative as A +import qualified Control.Concurrent.MVar as M import Control.Monad (ap) import Control.Monad.Fix (MonadFix, mfix) import Control.Monad.Trans (MonadIO, liftIO) @@ -73,7 +74,13 @@ data Error | NoComponentStreamID deriving (Show) -data Session = Session H.Handle Text SAX.Parser +data Session = Session + { sessionHandle :: H.Handle + , sessionNamespace :: Text + , sessionParser :: SAX.Parser + , sessionReadLock :: M.MVar () + , sessionWriteLock :: M.MVar () + } newtype XMPP a = XMPP { unXMPP :: ErrorT Error (R.ReaderT Session IO) a } @@ -105,22 +112,31 @@ runXMPP s xmpp = R.runReaderT (runErrorT (unXMPP xmpp)) s startXMPP :: H.Handle -> Text -> XMPP a -> IO (Either Error a) startXMPP h ns xmpp = do sax <- SAX.newParser - runXMPP (Session h ns sax) xmpp + readLock <- M.newMVar () + writeLock <- M.newMVar () + runXMPP (Session h ns sax readLock writeLock) xmpp restartXMPP :: Maybe H.Handle -> XMPP a -> XMPP a restartXMPP newH xmpp = do - Session oldH ns _ <- getSession + Session oldH ns _ readLock writeLock <- getSession sax <- liftIO SAX.newParser - let s = Session (maybe oldH id newH) ns sax + let s = Session (maybe oldH id newH) ns sax readLock writeLock XMPP $ R.local (const s) (unXMPP xmpp) +withLock :: (Session -> M.MVar ()) -> XMPP a -> XMPP a +withLock getLock xmpp = do + s <- getSession + let mvar = getLock s + res <- liftIO $ M.withMVar mvar $ \_ -> runXMPP s xmpp + case res of + Left err -> E.throwError err + Right x -> return x + getSession :: XMPP Session getSession = XMPP R.ask getHandle :: XMPP H.Handle -getHandle = do - Session h _ _ <- getSession - return h +getHandle = fmap sessionHandle getSession liftTLS :: ErrorT Text IO a -> XMPP a liftTLS io = do @@ -138,12 +154,12 @@ putElement :: X.Element -> XMPP () putElement = putBytes . encodeUtf8 . X.serialiseElement putStanza :: S.Stanza a => a -> XMPP () -putStanza = putElement . S.stanzaToElement +putStanza = withLock sessionWriteLock . putElement . S.stanzaToElement readEvents :: (Integer -> SAX.Event -> Bool) -> XMPP [SAX.Event] readEvents done = xmpp where xmpp = do - Session h _ p <- getSession + Session h _ p _ _ <- getSession let nextEvents = do -- TODO: read in larger increments bytes <- liftTLS $ H.hGetBytes h 1 @@ -170,9 +186,9 @@ getElement = xmpp where endOfTree _ _ = False getStanza :: XMPP S.ReceivedStanza -getStanza = do +getStanza = withLock sessionReadLock $ do elemt <- getElement - Session _ ns _ <- getSession + Session _ ns _ _ _ <- getSession case S.elementToStanza ns elemt of Just x -> return x Nothing -> E.throwError $ InvalidStanza elemt diff --git a/network-protocol-xmpp.cabal b/network-protocol-xmpp.cabal index 45ac512..0f642b6 100644 --- a/network-protocol-xmpp.cabal +++ b/network-protocol-xmpp.cabal @@ -1,5 +1,5 @@ name: network-protocol-xmpp -version: 0.3 +version: 0.3.1 synopsis: Client <-> Server communication over XMPP license: GPL-3 license-file: License.txt -- 2.45.2