~jojo/Carth

5807eef4e8fe28bbca5f86a4d9bc7453edf59a9b — JoJo 2 years ago f086b9c
Integrate Match with rest of compiler

Until now, the Match module was just kind of for "testing". It wasn't
actually used in the rest of the compiler. This commit replaces my
earlier, quite bad implementation of pattern matching to decision
trees with Sestofts one.

One major change made along the way was to how sub-matchees are
accessed in the code-generator and interpreter. Previously, the
decision tree was assumed to be structured such that we could simply
extract sub-matchees and put them on a stack. As far as I understand
it, that way is incompatible with Sestofts trees, as they can "skip"
steps that are known to be irrefutable in context. As such, I have
chosen to augment the `Access` type with some additional information,
such that the value of a sub-matchee can simply be generated from the
Access-path as needed, with memoization in the Selections module. See
`genSelect`.
M package.yaml => package.yaml +0 -1
@@ 28,7 28,6 @@ dependencies:
- composition
- mtl
- lens
- utility-ht
- llvm-hs-pure
- llvm-hs
- llvm-hs-pretty

M src/AnnotAst.hs => src/AnnotAst.hs +10 -4
@@ 12,6 12,8 @@ module AnnotAst
    , TypedVar(..)
    , Const(..)
    , VariantIx
    , Access(..)
    , VarBindings
    , DecisionTree(..)
    , Ction
    , Expr(..)


@@ 32,11 34,15 @@ data TypedVar = TypedVar String Type

type VariantIx = Word64

data Access = Obj | As Access [Type] | Sel Word32 Access
    deriving (Show, Eq, Ord)

type VarBindings = Map TypedVar Access

data DecisionTree
    = DecisionTree (Map VariantIx ([Type], DecisionTree))
                   (Maybe (TypedVar, DecisionTree))
    | DecisionLeaf Expr
    deriving (Show)
    = DLeaf (VarBindings, Expr)
    | DSwitch Access (Map VariantIx DecisionTree) DecisionTree
    deriving Show

type Ction = (VariantIx, (String, [Type]), [Expr])


M src/Check.hs => src/Check.hs +39 -121
@@ 3,20 3,18 @@

module Check (typecheck) where

import Prelude hiding (span)
import Control.Lens
    ((<<+=), assign, makeLenses, over, use, view, views, locally, mapped)
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Arrow ((>>>))
import Data.Either.Combinators
import Data.Bifunctor
import Data.Foldable
import Data.Graph (SCC(..), flattenSCC, stronglyConnComp)
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Maybe
import Data.Composition
import qualified Data.Set as Set
import Data.Set (Set)



@@ 29,14 27,17 @@ import qualified Ast
import Ast (Id, idstr, scmBody)
import TypeErr
import AnnotAst
import Match


data Env = Env
    { _envDefs :: Map String Scheme
    -- | Maps a constructor to its variant index in the type definition it
    --   constructs, the signature/left-hand-side of the type definition, and
    --   the types of its parameters
    , _envCtors :: Map String (VariantIx, (String, [TVar]), [Type])
    , _envTypeDefs :: Map String ([TVar], [[Type]])
    --   constructs, the signature/left-hand-side of the type definition, the
    --   types of its parameters, and the span (number of constructors) of the
    --   datatype
    , _envCtors :: Map String (VariantIx, (String, [TVar]), [Type], Span)
    , _envTypeDefs :: Map String ([TVar], [(String, [Type])])
    }
makeLenses ''Env



@@ 111,13 112,14 @@ inferProgram (Ast.Program defs tdefs) = do
    let expectedMainType = TFun (TPrim TUnit) (TPrim TUnit)
    unify (Expected expectedMainType) (Found mainPos mainT)
    let defs'' = Map.delete "main" defs'
    pure (Program main (Defs defs'') tdefs')
    let tdefs'' = fmap (second (map snd)) tdefs'
    pure (Program main (Defs defs'') (tdefs''))

checkTypeDefs
    :: [Ast.TypeDef]
    -> Infer
           ( Map String ([TVar], [[Type]])
           , Map String (VariantIx, (String, [TVar]), [Type])
           ( Map String ([TVar], [(String, [Type])])
           , Map String (VariantIx, (String, [TVar]), [Type], Span)
           )
checkTypeDefs =
    (fmap (second (fmap snd)) .)


@@ 133,16 135,18 @@ checkTypeDefs =
checkTypeDef
    :: Ast.TypeDef
    -> Infer
           ( (String, ([TVar], [[Type]]))
           , Map String (Id, (VariantIx, (String, [TVar]), [Type]))
           ( (String, ([TVar], [(String, [Type])]))
           , Map String (Id, (VariantIx, (String, [TVar]), [Type], Span))
           )
checkTypeDef (Ast.TypeDef (WithPos _ x) ps (Ast.ConstructorDefs cs)) = do
    let ps' = map TVExplicit ps
    let cs' = map snd cs
    let cs' = map (first idstr) cs
    let cSpan = length cs
    cs''' <- foldM
        (\cs'' (i, (cx, cps)) -> if Map.member (idstr cx) cs''
            then throwError (ConflictingCtorDef cx)
            else pure (Map.insert (idstr cx) (cx, (i, (x, ps'), cps)) cs'')
            else pure
                (Map.insert (idstr cx) (cx, (i, (x, ps'), cps, cSpan)) cs'')
        )
        Map.empty
        (zip [0 ..] cs)


@@ 200,7 204,7 @@ checkUserSchemes scms = forM_ scms $ \(WithPos p s1@(Forall _ t)) ->
        >>= \s2 -> when (s1 /= s2) (throwError (InvalidUserTypeSig p s1 s2))

infer :: Ast.Expr -> Infer (Type, Expr)
infer = unpos >>> \case
infer (WithPos pos e) = case e of
    Ast.Lit l -> pure (litType l, Lit l)
    Ast.Var x -> fmap (\t -> (t, Var (TypedVar (idstr x) t))) (lookupEnv x)
    Ast.App f a -> do


@@ 234,22 238,29 @@ infer = unpos >>> \case
    Ast.Match matchee cases -> do
        (tmatchee, matchee') <- infer matchee
        (tbody, cases') <- inferCases (Expected tmatchee) cases
        dt <- toDecisionTree tmatchee cases'
        dt <- toDecisionTree' pos tmatchee cases'
        pure (tbody, Match matchee' dt tbody)
    Ast.FunMatch cases -> do
        tpat <- fresh
        (tbody, cases') <- inferCases (Expected tpat) cases
        dt <- toDecisionTree tpat cases'
        dt <- toDecisionTree' pos tpat cases'
        let t = TFun tpat tbody
        x <- freshVar
        let e = Fun (x, tpat) (Match (Var (TypedVar x tpat)) dt tbody, tbody)
        pure (t, e)
    Ast.Ctor c -> inferExprConstructor c

data Pat
    = PConstruction VariantIx [Type] [Pat]
    | PVar String
    deriving (Show)
toDecisionTree' :: SrcPos -> Type -> [(SrcPos, Pat, Expr)] -> Infer DecisionTree
toDecisionTree' pos tpat cases = do
    -- TODO: Could we do this differently, more efficiently?
    --
    -- Match needs to be able to match on the pattern types to generate proper
    -- error messages for inexhaustive patterns, so apply substitutions.
    s <- use substs
    let tpat' = subst s tpat
    let cases' = map (\(pos, p, e) -> (pos, substPat s p, e)) cases
    mTypeDefs <- views envTypeDefs (fmap (map fst . snd))
    lift (lift (toDecisionTree mTypeDefs pos tpat' cases'))

-- | All the patterns must be of the same types, and all the bodies must be of
--   the same type.


@@ 279,12 290,13 @@ inferPat = \case
    Ast.PVar x -> do
        tv <- fresh'
        let tv' = TVar tv
        pure (tv', PVar (idstr x), Map.singleton x (Forall Set.empty tv'))
        let x' = TypedVar (idstr x) tv'
        pure (tv', PVar x', Map.singleton x (Forall Set.empty tv'))

inferPatConstruction
    :: SrcPos -> Id -> [Ast.Pat] -> Infer (Type, Pat, Map Id Scheme)
inferPatConstruction pos c cArgs = do
    (variantIx, tdefLhs, cParams) <- lookupEnvConstructor c
    (variantIx, tdefLhs, cParams, cSpan) <- lookupEnvConstructor c
    let arity = length cParams
    let nArgs = length cArgs
    unless (arity == nArgs) (throwError (CtorArityMismatch pos c arity nArgs))


@@ 294,7 306,8 @@ inferPatConstruction pos c cArgs = do
    cArgsVars' <- nonconflictingPatVarDefs cArgsVars
    forM_ (zip3 cParams' cArgTs cArgs) $ \(cParamT, cArgT, cArg) ->
        unify (Expected cParamT) (Found (getPos cArg) cArgT)
    pure (t, PConstruction variantIx cArgTs cArgs', cArgsVars')
    let con = Con { variant = variantIx, span = cSpan, argTs = cArgTs }
    pure (t, PCon con cArgs', cArgsVars')

nonconflictingPatVarDefs :: [Map Id Scheme] -> Infer (Map Id Scheme)
nonconflictingPatVarDefs = flip foldM Map.empty $ \acc ks ->


@@ 302,104 315,9 @@ nonconflictingPatVarDefs = flip foldM Map.empty $ \acc ks ->
        Just (WithPos pos v) -> throwError (ConflictingPatVarDefs pos v)
        Nothing -> pure (Map.union acc ks)

-- TODO: Check for exhaustiveness
-- | Build decision tree that matches out -> in, left -> right
--
--   For each variant/constructor, there is a node in the decision tree. When
--   picking a ctor in a column to create a sub-tree, remove all rows in that
--   column not beginning with that ctor, then splice the members of the variant
--   into the matrix.
toDecisionTree :: Type -> [(SrcPos, Pat, Expr)] -> Infer DecisionTree
toDecisionTree tpat cs =
    toDecisionTreeRows tpat [] (map (\(pos, p, e) -> (pos, p, [], e)) cs)

toDecisionTreeRows
    :: Type -> [Type] -> [(SrcPos, Pat, [Pat], Expr)] -> Infer DecisionTree
toDecisionTreeRows tpat tpats cases = do
    varName <- freshVar
    (ctorCases, varCases) <- foldlM
        (toDecisionTreeRow tpats varName)
        (Map.empty, [])
        cases
    buildDecisionTree ctorCases (TypedVar varName tpat) (tpats, varCases)

-- TODO: Generalize variable patterns to wildcard patterns, and add support for
--       binding a variable to an arbitrary pattern.  E.g. the case `a -> foo a`
--       is equivalent to `a@_ -> foo a`, and e.g. `a@Foo -> bar a` should be
--       allowed.
toDecisionTreeRow
    :: [Type]
    -> String
    -> ( Map VariantIx ([Type], [(SrcPos, [Pat], Expr)])
       , [(SrcPos, [Pat], Expr)]
       )
    -> (SrcPos, Pat, [Pat], Expr)
    -> Infer
           ( Map VariantIx ([Type], [(SrcPos, [Pat], Expr)])
           , [(SrcPos, [Pat], Expr)]
           )
toDecisionTreeRow ts varName (ctorCases, varCases) (pos, col, cols, body) =
    case col of
        PConstruction ctor cts ps ->
            -- Checks if constructor pattern is made redundant by earlier
            -- variable pattern
            if isRedundant ps (map (\(_, x, _) -> x) varCases)
                then throwError (RedundantCase pos)
                else
                    let
                        row' = (pos, ps ++ cols, body)
                        ts' = cts ++ ts
                        ctorCases' = insertWith'
                            (second (row' :))
                            ctor
                            (ts', [row'])
                            ctorCases
                    in pure (ctorCases', varCases)
        PVar x ->
            let
                body' = substVExpr (x, varName) body
                varCases' = (pos, cols, body') : varCases
            in pure (ctorCases, varCases')

isRedundant :: [Pat] -> [[Pat]] -> Bool
isRedundant ps = any (isRedundant' ps)

isRedundant' :: [Pat] -> [Pat] -> Bool
isRedundant' = all isRedundant'' .* zip

isRedundant'' :: (Pat, Pat) -> Bool
isRedundant'' = \case
    (PConstruction _ _ ps, PConstruction _ _ qs) -> isRedundant' ps qs
    (PConstruction _ _ _, PVar _) -> False
    (PVar _, _) -> True

buildDecisionTree
    :: Map VariantIx ([Type], [(SrcPos, [Pat], Expr)])
    -> TypedVar
    -> ([Type], [(SrcPos, [Pat], Expr)])
    -> Infer DecisionTree
buildDecisionTree ctorCases varLhs varCases@(_, varCases') = do
    ctorCases' <- forM ctorCases
        $ \cs@(ts, _) -> fmap (ts, ) (toDecisionTreeRows' cs)
    varDecisionTree <- if null varCases'
        then pure Nothing
        else fmap (Just . (varLhs, )) (toDecisionTreeRows' varCases)
    pure (DecisionTree ctorCases' varDecisionTree)

toDecisionTreeRows' :: ([Type], [(SrcPos, [Pat], Expr)]) -> Infer DecisionTree
toDecisionTreeRows' = \case
    ([], [(_, [], body)]) -> pure (DecisionLeaf body)
    -- If constructor or variable pattern is redundant by duplication
    ([], _ : cs) ->
        let (pos, _, _) = last cs in throwError (RedundantCase pos)
    (t : ts, cs) -> toDecisionTreeRows t ts $ flip map cs $ \case
        (_, [], _) -> ice "ps empty in toDecisionTreeRows'"
        (pos, p : ps, b) -> (pos, p, ps, b)
    x -> ice $ "unexpected pattern in toDecisionTreeRows': " ++ show x

inferExprConstructor :: Id -> Infer (Type, Expr)
inferExprConstructor c = do
    (variantIx, tdefLhs, cParams) <- lookupEnvConstructor c
    (variantIx, tdefLhs, cParams, _) <- lookupEnvConstructor c
    (tdefInst, cParams') <- instantiateConstructorOfTypeDef tdefLhs cParams
    cParams'' <- mapM (\t -> fmap (, t) freshVar) cParams'
    let cArgs = map (Var . uncurry TypedVar) cParams''


@@ 416,7 334,7 @@ instantiateConstructorOfTypeDef (tName, tParams) cParams = do
    let cParams' = map (subst (Map.fromList (zip tParams tVars))) cParams
    pure ((tName, tVars), cParams')

lookupEnvConstructor :: Id -> Infer (VariantIx, (String, [TVar]), [Type])
lookupEnvConstructor :: Id -> Infer (VariantIx, (String, [TVar]), [Type], Span)
lookupEnvConstructor (WithPos pos cx) =
    views envCtors (Map.lookup cx)
        >>= maybe (throwError (UndefCtor pos cx)) pure

M src/Codegen.hs => src/Codegen.hs +48 -56
@@ 39,9 39,7 @@ import Data.Map.Strict (Map)
import qualified Data.Set as Set
import Data.Word
import Data.Foldable
import Data.Functor
import Data.List
import Data.Maybe
import Data.Composition
import Control.Applicative
import Control.Lens


@@ 51,6 49,7 @@ import Misc
import FreeVars
import qualified MonoAst
import MonoAst hiding (Type, Const)
import Selections


type FFIType = Foreign.Ptr.Ptr LLPtrHierarchy.Type


@@ 171,7 170,6 @@ genBuiltins = map
        (mkName "malloc")
        [parameter (mkName "size") i64]
        (LLType.ptr typeUnit)
    , simpleFunc' (mkName "abort") [] typeUnit [LLFnAttr.NoReturn]
    , simpleFunc (mkName "printInt") [parameter (mkName "n") i64] typeUnit
    ]



@@ 349,60 347,60 @@ genDef (n, t, e) = genVar n t (genExpr e)
genMatch :: Expr -> DecisionTree -> Type -> Gen Operand
genMatch m dt tbody = do
    m' <- genExpr m
    genDecisionTree [m'] tbody dt

-- | During eval of decision trees, put sub-matchees on a stack, and they will
--   be popped as we go out -> in, left -> right. Stack starts with matchee.
genDecisionTree :: [Operand] -> Type -> DecisionTree -> Gen Operand
genDecisionTree ms tbody = \case
    DecisionTree cs vdt -> if Map.null cs
        then genVdt vdt
        else do
            let (variantIxs, variantDts) = unzip (Map.toAscList cs)
            variantLs <- mapM
                (newName . (++ "_") . ("variant_" ++) . show)
                variantIxs
            let dests = zip (map litU64 variantIxs) variantLs
            defaultL <- newName "default"
            nextL <- newName "next"
            let (m, ms') = fromJust (uncons ms)
            mVariantIx <- emitReg'
                "found_variant_ix"
                (extractvalueFromNamed m i64 [0])
            commitToNewBlock (switch mVariantIx defaultL dests) defaultL
            v <- genVdt vdt
            let genCase l dt = do
                    commitToNewBlock (br nextL) l
                    genDecisionTree' m ms' tbody dt
            vs <- zipWithM genCase variantLs variantDts
            commitToNewBlock (br nextL) nextL
            emitAnon (phi (zip (v : vs) (defaultL : variantLs)))
    DecisionLeaf b -> genExpr b
  where
    genVdt = \case
        Just (tv, dt) ->
            withLocal tv (head ms) (genDecisionTree (tail ms) tbody dt)
        -- If we fell through the last case, the pattern was nonexhaustive
        -- and we're in a failure state. Only thing to do now is panic!
        Nothing -> genAbort $> undef tbody

genDecisionTree'
    :: Operand
    -> [Operand]
    genDecisionTree tbody dt (newSelections m')

genDecisionTree :: Type -> DecisionTree -> Selections Operand -> Gen Operand
genDecisionTree tbody = \case
    MonoAst.DSwitch selector cs def -> genDecisionSwitch selector cs def tbody
    MonoAst.DLeaf l -> genDecisionLeaf l

genDecisionSwitch
    :: MonoAst.Access
    -> Map VariantIx DecisionTree
    -> DecisionTree
    -> Type
    -> ([MonoAst.Type], DecisionTree)
    -> Selections Operand
    -> Gen Operand
genDecisionTree' matchee stack tbody (ts, dt) = do
genDecisionSwitch selector cs def tbody selections = do
    let (variantIxs, variantDts) = unzip (Map.toAscList cs)
    variantLs <- mapM (newName . (++ "_") . ("variant_" ++) . show) variantIxs
    let dests = zip (map litU64 variantIxs) variantLs
    defaultL <- newName "default"
    nextL <- newName "next"
    (m, selections') <- genSelect selector selections
    mVariantIx <- emitReg' "found_variant_ix" (extractvalueFromNamed m i64 [0])
    commitToNewBlock (switch mVariantIx defaultL dests) defaultL
    v <- genDecisionTree tbody def selections'
    let genCase l dt = do
            commitToNewBlock (br nextL) l
            genDecisionTree tbody dt selections'
    vs <- zipWithM genCase variantLs variantDts
    commitToNewBlock (br nextL) nextL
    emitAnon (phi (zip (v : vs) (defaultL : variantLs)))

genDecisionLeaf
    :: (MonoAst.VarBindings, Expr) -> Selections Operand -> Gen Operand
genDecisionLeaf (bs, e) selections =
    flip withLocals (genExpr e) =<< genSelectVarBindings selections bs

genSelect :: Access -> Selections Operand -> Gen (Operand, Selections Operand)
genSelect = select genAs genSub

genSelectVarBindings
    :: Selections Operand -> VarBindings -> Gen [(TypedVar, Operand)]
genSelectVarBindings = selectVarBindings genAs genSub

genAs :: [MonoAst.Type] -> Operand -> Gen Operand
genAs ts matchee = do
    let tvariant = toLlvmVariantType ts
    let tgeneric = typeOf matchee
    pGeneric <- emitReg' "ction_ptr_generic" (alloca tgeneric)
    emit (store matchee pGeneric)
    p <- emitReg' "ction_ptr" (bitcast pGeneric (LLType.ptr tvariant))
    matchee' <- emitReg' "ction" (load p)
    subs <- mapM
        (emitReg' "submatchee" . extractvalue matchee' . pure)
        (take (length ts) [1 ..])
    genDecisionTree (subs ++ stack) tbody dt
    emitReg' "ction" (load p)

genSub :: Word32 -> Operand -> Gen Operand
genSub i matchee = emitReg' "submatchee" (extractvalue matchee (pure (i + 1)))

genCtion :: MonoAst.Ction -> Gen Operand
genCtion (i, tdef, as) = do


@@ 455,9 453,6 @@ genBoxGeneric x = do
genMalloc :: Operand -> Gen Operand
genMalloc size = emitAnon (callExtern "malloc" (LLType.ptr typeUnit) [size])

genAbort :: Gen ()
genAbort = emit (callExtern' "abort" [])

semiExecRetGen :: Gen Operand -> Gen' (Type, Out)
semiExecRetGen gx = runWriterT $ do
    x <- gx


@@ 618,9 613,6 @@ newName'' s =
callExtern :: String -> Type -> [Operand] -> FunInstruction
callExtern f rt as = WithRetType (callExtern'' f rt as) rt

callExtern' :: String -> [Operand] -> Instruction
callExtern' f as = callExtern'' f typeUnit as

callExtern'' :: String -> Type -> [Operand] -> Instruction
callExtern'' f rt as = Call
    { tailCallKind = Just Tail

M src/Interp.hs => src/Interp.hs +48 -28
@@ 4,15 4,15 @@ module Interp (interpret) where

import Control.Applicative (liftA3)
import Control.Monad.Reader
import Data.Bool.HT
import Data.Functor
import Data.Map.Lazy (Map)
import qualified Data.Map.Lazy as Map
import Data.Maybe
import Data.List
import Data.Word

import Misc
import MonoAst
import Selections

data Val
    = VConst Const


@@ 24,6 24,13 @@ type Env = Map TypedVar Val

type Eval = ReaderT Env IO


instance Show Val where
    show = \case
        VConst c -> "VConst " ++ show c ++ ""
        VFun _ -> "VFun"
        VConstruction c xs -> "VConstruction " ++ show c ++ " " ++ show xs

interpret :: Program -> IO ()
interpret p = runEval (evalProgram p)



@@ 67,7 74,7 @@ eval = \case
    Let defs body -> evalLet defs body
    Match e dt _ -> do
        v <- eval e
        evalDecisionTree [v] dt
        evalDecisionTree dt (newSelections v)
    Ction (i, _, as) -> fmap (VConstruction i) (mapM eval as)

evalApp :: Expr -> Expr -> Eval Val


@@ 80,32 87,45 @@ evalLet defs body = do
    defs' <- evalDefs defs
    withLocals defs' (eval body)

-- | Out to in, left to right.
evalDecisionTree :: [Val] -> DecisionTree -> Eval Val
evalDecisionTree stack = \case
    DecisionTree cs default' -> do
evalDecisionTree :: DecisionTree -> Selections Val -> Eval Val
evalDecisionTree = \case
    DSwitch selector cs def -> evalDecisionSwitch selector cs def
    DLeaf l -> evalDecisionLeaf l

evalDecisionSwitch
    :: Access
    -> Map VariantIx DecisionTree
    -> DecisionTree
    -> Selections Val
    -> Eval Val
evalDecisionSwitch selector cs def selections = do
    (m, selections') <- evalSelect selector selections
    case m of
        VConstruction ctor _ ->
            evalDecisionTree (fromMaybe def (Map.lookup ctor cs)) selections'
        _ -> ice "not VConstruction in evalDecisionSwitch"

evalDecisionLeaf :: (VarBindings, Expr) -> Selections Val -> Eval Val
evalDecisionLeaf (bs, e) selections = flip withLocals (eval e)
    =<< fmap Map.fromList (evalSelectVarBindings selections bs)

evalSelect :: Access -> Selections Val -> Eval (Val, Selections Val)
evalSelect = select evalAs evalSub

evalSelectVarBindings :: Selections Val -> VarBindings -> Eval [(TypedVar, Val)]
evalSelectVarBindings = selectVarBindings evalAs evalSub

evalAs :: [MonoAst.Type] -> Val -> Eval Val
evalAs _ = pure

evalSub :: Word32 -> Val -> Eval Val
evalSub i = \case
    v@(VConstruction _ xs) ->
        let
            (m, ms) = fromMaybe
                (ice "Stack is empty in evalDecisionTree")
                (uncons stack)
        evalDecisionTree' m ms cs >>= \case
            Just v -> pure v
            Nothing -> case default' of
                Just (tv, d) -> withLocal tv m (evalDecisionTree ms d)
                Nothing ->
                    ice "default' is Nothing after fail in evalDecisionTree"
    DecisionLeaf e -> eval e

evalDecisionTree'
    :: Val
    -> [Val]
    -> Map VariantIx (VariantTypes, DecisionTree)
    -> Eval (Maybe Val)
evalDecisionTree' m ms cs = case m of
    VConstruction ctor xs -> case Map.lookup ctor cs of
        Just (_, d) -> fmap Just (evalDecisionTree (xs ++ ms) d)
        Nothing -> pure Nothing
    _ -> pure Nothing
            i' = fromIntegral i
            msg = "i >= length xs in evalSub: " ++ (show i ++ ", " ++ show v)
        in pure (if i' < length xs then xs !! i' else ice msg)
    _ -> ice "evalSub of non VConstruction"

lookupEnv :: (String, Type) -> Eval Val
lookupEnv (x, t) = fmap

M src/Match.hs => src/Match.hs +56 -52
@@ 1,10 1,10 @@
{-# LANGUAGE LambdaCase, TemplateHaskell #-}
{-# LANGUAGE LambdaCase, TemplateHaskell, TupleSections #-}

-- | Implementation of the algorithm described in /ML pattern match compilation
--   and partial evaluation/ by Peter Sestoft. Close to 1:1, and includes the
--   additional checks for exhaustiveness and redundancy described in section
--   7.4.
module Match where
module Match (toDecisionTree, Span, Con(..), Pat(..), MTypeDefs) where

import Prelude hiding (span)
import qualified Data.Set as Set


@@ 14,10 14,10 @@ import Data.Map (Map)
import Data.Maybe
import Data.List (delete)
import Data.Functor
import Control.Applicative
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Monad.Except
import Data.Word
import Control.Lens (makeLenses, view, views)

import Misc hiding (augment)


@@ 26,11 26,17 @@ import TypeErr
import AnnotAst


data Con = Con { variant :: VariantIx, arity :: Int, span :: Int }
type Span = Int

data Con = Con
    { variant :: VariantIx
    , span :: Span
    , argTs :: [Type]
    }
    deriving Show

data Pat
    = PVar String
    = PVar TypedVar
    | PCon Con [Pat]
    deriving Show



@@ 43,21 49,18 @@ type Ctx = [(Con, [Descr])]

type Work = [([Pat], [Access], [Descr])]

data Access = Obj | Sel Int Access
    deriving (Show, Eq)

data DecisionDag
    = Success ([(String, Access)], Expr)
    | IfEq Access Con DecisionDag DecisionDag
data DecisionTree'
    = Success (VarBindings, Expr)
    | IfEq Access Con DecisionTree' DecisionTree'
    deriving Show

type Rhs = (SrcPos, [(String, Access)], Expr)
type Rhs = (SrcPos, VarBindings, Expr)

type TypeDefs' = Map String [(String, [Type])]
type MTypeDefs = Map String [String]

type RedundantCases = [SrcPos]

data Env = Env { _tdefs :: TypeDefs', _tpat :: Type, _exprPos :: SrcPos }
data Env = Env { _tdefs :: MTypeDefs, _tpat :: Type, _exprPos :: SrcPos }
makeLenses ''Env

type Match = ReaderT Env (StateT RedundantCases (Except TypeErr))


@@ 70,28 73,28 @@ instance Ord Con where
    compare (Con c1 _ _) (Con c2 _ _) = compare c1 c2


compile
    :: TypeDefs'
toDecisionTree
    :: MTypeDefs
    -> SrcPos
    -> Type
    -> [(SrcPos, Pat, Expr)]
    -> Either TypeErr DecisionDag'
compile tds exprPos tpat cases =
    -> Except TypeErr DecisionTree
toDecisionTree tds exprPos tpat cases =
    let
        rules = map (\(pos, p, e) -> (p, (pos, [], e))) cases
        rules = map (\(pos, p, e) -> (p, (pos, Map.empty, e))) cases
        redundantCases = map (\(pos, _, _) -> pos) cases
    in runExcept $ do
    in do
        let env = Env { _tdefs = tds, _tpat = tpat, _exprPos = exprPos }
        (d, redundantCases') <- runStateT
            (runReaderT (compile' rules) env)
            (runReaderT (compile rules) env)
            redundantCases
        forM_ redundantCases' $ throwError . RedundantCase
        pure (switchify d)

compile' :: [(Pat, Rhs)] -> Match DecisionDag
compile' = disjunct (Neg Set.empty)
compile :: [(Pat, Rhs)] -> Match DecisionTree'
compile = disjunct (Neg Set.empty)

disjunct :: Descr -> [(Pat, Rhs)] -> Match DecisionDag
disjunct :: Descr -> [(Pat, Rhs)] -> Match DecisionTree'
disjunct descr = \case
    [] -> do
        patStr <- view tpat >>= flip missingPat descr


@@ 108,22 111,24 @@ missingPat t descr = case t of
        missingPat' vs descr
    TFun _ _ -> pure "_"

missingPat' :: [(String, [Type])] -> Descr -> Match String
missingPat' :: [String] -> Descr -> Match String
missingPat' vs = \case
    Neg cs -> pure $ head $ Map.elems
        (Map.withoutKeys (allVariants vs) (Set.map variant cs))
    Pos con dargs ->
        let
            i = fromIntegral (variant con)
            (s, ts) = vs !! i
            s = if i < length vs
                then vs !! i
                else ice "variant >= type number of variants in missingPat'"
        in if null dargs
            then pure s
            else do
                ps <- zipWithM missingPat ts dargs
                ps <- zipWithM missingPat (argTs con) dargs
                pure ("(" ++ s ++ precalate " " ps ++ ")")

allVariants :: [(String, [Type])] -> Map VariantIx String
allVariants = Map.fromList . zip [0 ..] . map fst
allVariants :: [String] -> Map VariantIx String
allVariants = Map.fromList . zip [0 ..]

match
    :: Access


@@ 133,15 138,15 @@ match
    -> Rhs
    -> [(Pat, Rhs)]
    -> Pat
    -> Match DecisionDag
    -> Match DecisionTree'
match obj descr ctx work rhs rules = \case
    PVar x -> conjunct (augment descr ctx) (addBind x obj rhs) rules work
    PCon pcon pargs ->
        let
            disjunct' :: Descr -> Match DecisionDag
            disjunct' :: Descr -> Match DecisionTree'
            disjunct' newDescr = disjunct (buildDescr newDescr ctx work) rules

            conjunct' :: Match DecisionDag
            conjunct' :: Match DecisionTree'
            conjunct' = conjunct
                ((pcon, []) : ctx)
                rhs


@@ 149,22 154,24 @@ match obj descr ctx work rhs rules = \case
                ((pargs, getoargs, getdargs) : work)

            getoargs :: [Access]
            getoargs = args (\i -> Sel i obj)
            getoargs = args (\i -> Sel i (As obj (argTs pcon)))

            getdargs :: [Descr]
            getdargs = case descr of
                Neg _ -> args (const (Neg Set.empty))
                Pos _ dargs -> dargs

            args :: (Int -> a) -> [a]
            args f = map f ([0 .. arity pcon - 1])
            args :: (Word32 -> a) -> [a]
            args f = map f (take (arity pcon) [0 ..])
        in case staticMatch pcon descr of
            Yes -> conjunct'
            No -> disjunct' descr
            Maybe ->
                liftA2 (IfEq obj pcon) conjunct' (disjunct' (addneg pcon descr))
            Maybe -> do
                yes <- conjunct'
                no <- disjunct' (addneg pcon descr)
                pure (IfEq obj pcon yes no)

conjunct :: Ctx -> Rhs -> [(Pat, Rhs)] -> Work -> Match DecisionDag
conjunct :: Ctx -> Rhs -> [(Pat, Rhs)] -> Work -> Match DecisionTree'
conjunct ctx rhs@(casePos, binds, e) rules = \case
    [] -> caseReached casePos $> Success (binds, e)
    (work1 : workr) -> case work1 of


@@ 176,8 183,11 @@ conjunct ctx rhs@(casePos, binds, e) rules = \case
caseReached :: SrcPos -> Match ()
caseReached p = modify (delete p)

addBind :: String -> Access -> Rhs -> Rhs
addBind x obj (pos, binds, e) = (pos, (x, obj) : binds, e)
addBind :: TypedVar -> Access -> Rhs -> Rhs
addBind x obj (pos, binds, e) = (pos, Map.insert x obj binds, e)

arity :: Con -> Int
arity = length . argTs

buildDescr :: Descr -> Ctx -> Work -> Descr
buildDescr descr = curry $ \case


@@ 211,22 221,16 @@ addneg con = \case
    Neg nonset -> Neg (Set.insert con nonset)
    Pos _ _ -> ice "unexpected Pos in addneg"

data DecisionDag'
    = Leaf ([(String, Access)], Expr)
    | Switch Access (Map VariantIx DecisionDag') DecisionDag'
    deriving Show

switchify :: DecisionDag -> DecisionDag'
switchify :: DecisionTree' -> DecisionTree
switchify = \case
    Success e -> Leaf e
    IfEq obj con d0 d1 ->
        uncurry (Switch obj) (switchify' obj [(variant con, switchify d0)] d1)
    Success e -> DLeaf e
    d@(IfEq obj _ _ _) -> uncurry (DSwitch obj) (switchify' obj [] d)

switchify'
    :: Access
    -> [(VariantIx, DecisionDag')]
    -> DecisionDag
    -> (Map VariantIx DecisionDag', DecisionDag')
    -> [(VariantIx, DecisionTree)]
    -> DecisionTree'
    -> (Map VariantIx DecisionTree, DecisionTree)
switchify' obj rules = \case
    IfEq obj' con d0 d1 | obj == obj' ->
        switchify' obj ((variant con, switchify d0) : rules) d1

M src/Misc.hs => src/Misc.hs +4 -0
@@ 14,6 14,7 @@ module Misc
    , both
    , augment
    , insertWith'
    , if'
    )
where



@@ 84,3 85,6 @@ augment l = locally l . Map.union

insertWith' :: Ord k => (v -> v) -> k -> v -> Map k v -> Map k v
insertWith' f = Map.insertWith (f .* flip const)

if' :: Bool -> a -> a -> a
if' p c a = if p then c else a

M src/Mono.hs => src/Mono.hs +28 -11
@@ 93,17 93,31 @@ monoMatch e dt tbody =

monoDecisionTree :: An.DecisionTree -> Mono DecisionTree
monoDecisionTree = \case
    An.DecisionTree cs vdt -> do
        cs' <- mapM (bimapM (mapM monotype) monoDecisionTree) cs
        vdt' <- flip (maybe (pure Nothing)) vdt $ \(An.TypedVar x t, dt) -> do
            parentInst <- uses defInsts (Map.lookup x)
            modifying defInsts (Map.delete x)
            t' <- monotype t
            dt' <- monoDecisionTree dt
            maybe (pure ()) (modifying defInsts . Map.insert x) parentInst
            pure (Just (TypedVar x t', dt'))
        pure (DecisionTree cs' vdt')
    An.DecisionLeaf e -> fmap DecisionLeaf (mono e)
    An.DSwitch obj cs def -> do
        obj' <- monoAccess obj
        cs' <- mapM monoDecisionTree cs
        def' <- monoDecisionTree def
        pure (DSwitch obj' cs' def')
    An.DLeaf (bs, e) -> do
        let bs' = Map.toList bs
        let ks = map (\((An.TypedVar x _), _) -> x) bs'
        parentInsts <- uses defInsts (lookups ks)
        modifying defInsts (deletes ks)
        bs'' <- mapM
            (bimapM
                (\(An.TypedVar x t) -> fmap (TypedVar x) (monotype t))
                monoAccess
            )
            bs'
        e' <- mono e
        modifying defInsts (Map.union (Map.fromList parentInsts))
        pure (DLeaf (bs'', e'))

monoAccess :: An.Access -> Mono Access
monoAccess = \case
    An.Obj -> pure Obj
    An.As a ts -> liftA2 As (monoAccess a) (mapM monotype ts)
    An.Sel i a -> fmap (Sel i) (monoAccess a)

monoCtion :: An.Ction -> Mono Expr
monoCtion (i, (tdefName, tdefArgs), as) = do


@@ 167,3 181,6 @@ lookup' = Map.findWithDefault

lookups :: Ord k => [k] -> Map k v -> [(k, v)]
lookups ks m = catMaybes (map (\k -> fmap (k, ) (Map.lookup k m)) ks)

deletes :: (Foldable t, Ord k) => t k -> Map k v -> Map k v
deletes = flip (foldr Map.delete)

M src/MonoAst.hs => src/MonoAst.hs +15 -10
@@ 11,6 11,8 @@ module MonoAst
    , Const(..)
    , VariantIx
    , VariantTypes
    , Access(..)
    , VarBindings
    , DecisionTree(..)
    , Ction
    , Expr(..)


@@ 25,8 27,9 @@ import qualified Data.Map as Map
import Data.Map (Map)
import qualified Data.Set as Set
import Data.Set (Set)
import AnnotAst (VariantIx)
import Data.Word

import AnnotAst (VariantIx)
import FreeVars
import Ast (Const(..), TPrim(..))



@@ 43,11 46,15 @@ data TypedVar = TypedVar String Type

type VariantTypes = [Type]

data Access = Obj | As Access [Type] | Sel Word32 Access
    deriving (Show, Eq, Ord)

type VarBindings = [(TypedVar, Access)]

data DecisionTree
    = DecisionTree (Map VariantIx (VariantTypes, DecisionTree))
                   (Maybe (TypedVar, DecisionTree))
    | DecisionLeaf Expr
    deriving (Show)
    = DLeaf (VarBindings, Expr)
    | DSwitch Access (Map VariantIx DecisionTree) DecisionTree
    deriving Show

type Ction = (VariantIx, TConst, [Expr])



@@ 88,11 95,9 @@ fvExpr = \case

fvDecisionTree :: DecisionTree -> Set TypedVar
fvDecisionTree = \case
    DecisionTree cs vdt ->
        Set.unions
            $ maybe Set.empty (\(v, dt) -> Set.delete v (fvDecisionTree dt)) vdt
            : map (fvDecisionTree . snd) (Map.elems cs)
    DecisionLeaf e -> fvExpr e
    DSwitch _ cs def ->
        Set.unions $ fvDecisionTree def : map fvDecisionTree (Map.elems cs)
    DLeaf (bs, e) -> Set.difference (fvExpr e) (Set.fromList (map fst bs))

mainType :: Type
mainType = TFun (TPrim TUnit) (TPrim TUnit)

M src/Parse.hs => src/Parse.hs +1 -1
@@ 42,7 42,7 @@ import Data.Void
import Data.Composition
import Data.List

import Misc
import Misc hiding (if')
import SrcPos
import Ast
import NonEmpty

A src/Selections.hs => src/Selections.hs +54 -0
@@ 0,0 1,54 @@
{-# LANGUAGE LambdaCase, TupleSections #-}

module Selections (Selections, newSelections, select, selectVarBindings) where

import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Word
import Control.Monad

import Misc
import MonoAst


type Selections a = Map Access a


newSelections :: a -> Selections a
newSelections x = Map.singleton Obj x

select
    :: (Show a, Monad m)
    => ([Type] -> a -> m a)
    -> (Word32 -> a -> m a)
    -> Access
    -> Selections a
    -> m (a, Selections a)
select conv sub selector selections = case Map.lookup selector selections of
    Just a -> pure (a, selections)
    Nothing -> do
        (a, selections') <- case selector of
            Obj -> ice "select: Obj not in selections"
            As x ts -> do
                (a', s') <- select conv sub x selections
                a'' <- conv ts a'
                pure (a'', s')
            Sel i x -> do
                (a', s') <- select conv sub x selections
                a'' <- sub i a'
                pure (a'', s')
        pure (a, Map.insert selector a selections')

selectVarBindings
    :: (Show a, Monad m)
    => ([Type] -> a -> m a)
    -> (Word32 -> a -> m a)
    -> Selections a
    -> VarBindings
    -> m [(TypedVar, a)]
selectVarBindings conv sub selections = fmap fst . foldM
    (\(bs', ss) (x, s) -> do
        (a, ss') <- select conv sub s ss
        pure ((x, a) : bs', ss')
    )
    ([], selections)

M src/Subst.hs => src/Subst.hs +21 -59
@@ 1,27 1,18 @@
{-# LANGUAGE LambdaCase #-}

module Subst
    ( Subst
    , subst
    , substProgram
    , composeSubsts
    , VarSubst
    , substVExpr
    )
where
module Subst (Subst, subst, substProgram, substPat, composeSubsts) where

import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Bifunctor
import Data.Maybe

import Match
import AnnotAst

-- | Map of substitutions from type-variables to more specific types
type Subst = Map TVar Type

type VarSubst = (String, String)

substProgram :: Subst -> Program -> Program
substProgram s (Program main (Defs defs) tdefs) =
    Program (substExpr s main) (Defs (fmap (substDef s) defs)) tdefs


@@ 45,15 36,25 @@ substExpr s = \case

substDecisionTree :: Subst -> DecisionTree -> DecisionTree
substDecisionTree s = \case
    DecisionTree cs vdt -> DecisionTree
        (fmap (\(ts, dt) -> (map (subst s) ts, substDecisionTree s dt)) cs)
        (fmap
            (\(TypedVar x t, dt) ->
                (TypedVar x (subst s t), substDecisionTree s dt)
            )
            vdt
        )
    DecisionLeaf e -> DecisionLeaf (substExpr s e)
    DSwitch obj cs def -> DSwitch
        (substAccess s obj)
        (fmap (substDecisionTree s) cs)
        (substDecisionTree s def)
    DLeaf e -> DLeaf (second (substExpr s) e)

substAccess :: Subst -> Access -> Access
substAccess s = \case
    Obj -> Obj
    As a ts -> As (substAccess s a) (map (subst s) ts)
    Sel i a -> Sel i (substAccess s a)

substPat :: Subst -> Pat -> Pat
substPat s = \case
    PVar (TypedVar x t) -> PVar (TypedVar x (subst s t))
    PCon c ps -> PCon (substCon s c) (map (substPat s) ps)

substCon :: Subst -> Con -> Con
substCon s (Con ix sp ts) = Con ix sp (map (subst s) ts)

subst :: Subst -> Type -> Type
subst s t = case t of


@@ 64,42 65,3 @@ subst s t = case t of

composeSubsts :: Subst -> Subst -> Subst
composeSubsts s1 s2 = Map.union (fmap (subst s1) s2) s1

substVExpr :: VarSubst -> Expr -> Expr
substVExpr s = \case
    Lit c -> Lit c
    Var (TypedVar x t) -> Var (TypedVar (substV s x) t)
    App f a -> App (substVExpr s f) (substVExpr s a)
    If p c a -> If (substVExpr s p) (substVExpr s c) (substVExpr s a)
    Fun p b -> substVFun s p b
    Let (Defs defs) body -> substVLet s defs body
    Match e dt t -> Match (substVExpr s e) (substVDecisionTree s dt) t
    Ction (i, t, es) -> Ction (i, t, map (substVExpr s) es)

substVFun :: VarSubst -> (String, Type) -> (Expr, Type) -> Expr
substVFun s@(from, _) p@(p', _) b@(b', tb) =
    if p' == from then Fun p b else Fun p (substVExpr s b', tb)

substVLet :: VarSubst -> Map String (Scheme, Expr) -> Expr -> Expr
substVLet s@(from, _) defs body =
    let
        defs' = Map.mapWithKey
            (\k (scm, e) -> (scm, if from == k then e else substVExpr s e))
            defs
        body' = if Map.member from defs then body else substVExpr s body
    in Let (Defs defs') body'

substVDecisionTree :: VarSubst -> DecisionTree -> DecisionTree
substVDecisionTree s = \case
    DecisionTree cs vdt -> DecisionTree
        (fmap (\(ts, dt) -> (ts, substVDecisionTree s dt)) cs)
        (fmap
            (\(TypedVar x t, dt) ->
                (TypedVar (substV s x) t, substVDecisionTree s dt)
            )
            vdt
        )
    DecisionLeaf e -> DecisionLeaf (substVExpr s e)

substV :: VarSubst -> String -> String
substV (from, to) var = if var == from then to else var