~jojo/Carth

ref: a07d1ce73f8cb27624ea7647d1a339c2af537ca5 Carth/src/Parser.hs -rw-r--r-- 3.3 KiB
a07d1ce7JoJo Infer: Minor refactor 1 year, 3 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
{-# LANGUAGE FlexibleContexts, LambdaCase, TupleSections, DataKinds
           , GeneralizedNewtypeDeriving #-}

module Parser (Parser, St (..), runParser, token, end, lookAhead, try, getSrcPos) where

import Control.Applicative hiding (many, some)
import Control.Monad
import Control.Monad.State.Strict
import Control.Monad.Except
import Data.Functor
import Data.List
import Data.Set (Set)
import qualified Data.Set as Set

import Misc
import SrcPos
import Lexd

-- | A parse failure.
--
--   The fields are intentionally left lazy: 'errPos' may hold a placeholder position
--   that must not be evaluated unless the error is actually inspected.
data Err = Err
    { errLength :: Word
      -- ^ How far parsing had progressed when the error was raised; used to rank
      --   competing errors against each other.
    , errPos :: SrcPos
      -- ^ Source position the error is reported at.
    , errExpecteds :: Set String
      -- ^ Descriptions of what would have been accepted at 'errPos'.
    }
    deriving (Show, Eq)

-- The Semigroup instance for the error type selects the error with the longest match.
-- The length of the match is tracked in `errLength`.
--
-- If two matches are of the same / no length, the sets of "expected"s of both are
-- combined, so the message reads like "Expected one of: keyword data, keyword extern".
instance Semigroup Err where
    -- The identity check comes first so that combining with 'mempty' on the right
    -- never replaces a real position with the dummy one.
    l <> r
        | r == mempty = l
        | otherwise = case compare (errLength l) (errLength r) of
            -- The error that got further wins outright.
            GT -> l
            LT -> r
            -- Tie: keep the right error's length and position, and report the
            -- expectations of both.
            EQ -> r { errExpecteds = Set.union (errExpecteds l) (errExpecteds r) }

instance Monoid Err where
    -- The neutral error: zero length, a dummy position, and nothing expected.
    mempty = Err
        { errLength = 0
        , errPos = SrcPos "<dummy>" 0 0
        , errExpecteds = Set.empty
        }

-- | Parser state.
--
--   The fields are intentionally left lazy: at top level 'stOuterPos' is seeded with
--   a placeholder that must not be evaluated unless it is actually needed.
data St = St
    { stCount :: Word
      -- ^ Running count of consumed tokens; 'token' increments it by one per token.
    , stOuterPos :: SrcPos
      -- ^ Position to report when the current input sequence runs out of tokens.
    , stInput :: [TokenTree]
      -- ^ Remaining input of the current token sequence.
    }

-- | Parser over a sequence of token trees: the parser state 'St' is threaded through
--   a computation that may fail with an 'Err'.
--
--   'MonadPlus' is deliberately /not/ newtype-derived. A derived instance would reuse
--   'mzero' / 'mplus' of the underlying @StateT St (Except Err)@, bypassing the
--   hand-written 'Alternative' instance for 'Parser', so 'mplus' (and anything built
--   on it, like 'msum') could disagree with '(<|>)'. The explicit instance below
--   relies on the class defaults instead (@mzero = empty@, @mplus = (<|>)@).
newtype Parser a = Parser (StateT St (Except Err) a)
    deriving (Functor, Applicative, Monad, MonadError Err, MonadState St)

instance MonadPlus Parser

instance Alternative Parser where
    -- Fails with the neutral error.
    empty = Parser (throwError mempty)

    -- Choice that commits to the left parser once it has made progress: the right
    -- parser is only attempted when the left error's length does not exceed the token
    -- count recorded at the choice point. If both fail, their errors are merged.
    left <|> right = do
        startCount <- gets stCount
        let recover leftErr
                | errLength leftErr > startCount = throwError leftErr
                | otherwise = right `catchError` (throwError . (leftErr <>))
        left `catchError` recover

-- | Run a parser on a sequence of token trees, starting from a token count of zero.
--
--   On failure, yields the position of the error together with a message listing
--   what was expected there.
runParser :: Parser a -> [TokenTree] -> Except (SrcPos, String) a
runParser (Parser ma) tts = withExcept report (evalStateT ma initialState)
  where
    -- There is no enclosing position at top level, so reading it is reported as an
    -- internal error.
    initialState = St
        { stCount = 0
        , stOuterPos = ice "read SrcPos in parser state at top level"
        , stInput = tts
        }

    report (Err _ pos expecteds) = (pos, describe (Set.toList expecteds))

    describe [] = ice "empty list of expecteds in formatExpecteds"
    describe [single] = "Expected " ++ single
    describe several = "Expected one of: " ++ intercalate ", " several

-- | Consume the next token tree if the given function accepts it.
--
--   The first argument describes what is expected and is used in the error raised
--   when the function returns 'Nothing'. On success the token count is incremented and
--   the token tree is dropped from the input. If the input is empty, fails at the
--   outer position stored in the state.
token :: String -> (SrcPos -> TokenTree' -> Maybe a) -> Parser a
token expected match = get >>= \(St count outerPos input) -> case input of
    [] ->
        throwError $ Err count outerPos (Set.singleton "continuation of token sequence")
    WithPos pos tok : rest -> case match pos tok of
        Nothing -> throwError (Err count pos (Set.singleton expected))
        Just result -> do
            modify (\st -> st { stCount = count + 1, stInput = rest })
            pure result

-- | Succeeds only if nothing is left of the current input sequence (which may be the
--   contents of a nested sexpr). Otherwise fails at the position of the next token.
end :: Parser ()
end = do
    St count _ remaining <- get
    case remaining of
        [] -> pure ()
        WithPos pos _ : _ ->
            throwError $ Err count pos (Set.singleton "end of (nested) token sequence")

-- | Lift a 'Maybe' into an error monad: 'Nothing' throws the given error, while
--   'Just' yields the contained value.
mexcept :: MonadError e m => e -> Maybe a -> m a
mexcept err Nothing = throwError err
mexcept _ (Just val) = pure val

-- | Run a parser and, if it succeeds, put back the parser state from before it ran,
--   so that its result is obtained without consuming input. A failure is propagated
--   as is.
lookAhead :: Parser a -> Parser a
lookAhead pa = do
    saved <- get
    result <- pa
    put saved
    pure result

-- | Run a parser; if it fails, put back the state from before it ran and rethrow the
--   error with its length reset to the token count at that point. An enclosing
--   '(<|>)' starting from the same point therefore sees no progress and may still
--   attempt its other branch.
try :: Parser a -> Parser a
try ma = get >>= \saved ->
    let rewind err = do
            put saved
            throwError (err { errLength = stCount saved })
    in  ma `catchError` rewind

-- | Position of the next token tree, obtained without consuming it. Fails, like
--   'token', when the current input sequence is empty.
getSrcPos :: Parser SrcPos
getSrcPos = lookAhead nextPos
  where
    -- Accept any token and keep only its position.
    nextPos = token "token" (\pos _ -> Just pos)