~jojo/Carth

0ef0734e1ebaff842a5b418391d619b1613c4fdc — JoJo 2 years ago aef49fc
Improve what data about ctors & tdefs is passed from Check to Codegen

Codegen should have to do as few environment lookups and the like as
possible when they can be done earlier, in Check and Mono. Check now
manages to simplify what data is passed along for constructors. One
thing that is gone is the constructor name--only the variant index
(which is what matters) is preserved.
10 files changed, 287 insertions(+), 162 deletions(-)

M src/AnnotAst.hs
M src/Ast.hs
M src/Check.hs
M src/Codegen.hs
M src/Interp.hs
M src/Misc.hs
M src/Mono.hs
M src/MonoAst.hs
M src/Parse.hs
M src/TypeErr.hs
M src/AnnotAst.hs => src/AnnotAst.hs +15 -4
@@ 6,30 6,39 @@
module AnnotAst
    ( TVar(..)
    , TPrim(..)
    , TConst
    , Type(..)
    , Scheme(..)
    , TypedVar(..)
    , VariantIx
    , Pat(..)
    , Ctor
    , Const(..)
    , Expr(..)
    , Defs(..)
    , TypeDefs
    , Program(..)
    )
where

import Data.Map.Strict (Map)
import Data.Word

import Ast (TVar(..), TPrim(..), Type(..), Scheme(..), Const(..))
import Ast (TVar(..), TPrim(..), TConst, Type(..), Scheme(..), Const(..))


data TypedVar = TypedVar String Type
    deriving (Show, Eq, Ord)

type VariantIx = Word64

data Pat
    = PConstruction String [Pat]
    = PConstruction VariantIx [Pat]
    | PVar TypedVar
    deriving (Show, Eq)

-- | (Variant index, name and type arguments of the constructed type,
--   parameter types of this variant)
type Ctor = (VariantIx, (String, [Type]), [Type])

data Expr
    = Lit Const
    | Var TypedVar


@@ 38,11 47,13 @@ data Expr
    | Fun (String, Type) (Expr, Type)
    | Let Defs Expr
    | Match Expr [(Pat, Expr)]
    | Constructor String
    | Ctor Ctor
    deriving (Show)

newtype Defs = Defs (Map String (Scheme, Expr))
    deriving (Show)

data Program = Program Expr Defs
-- | Maps the name of a type definition to its type parameters and its
--   variants, where each variant is given as the list of its member types
type TypeDefs = Map String ([TVar], [[Type]])

data Program = Program Expr Defs TypeDefs
    deriving (Show)

M src/Ast.hs => src/Ast.hs +13 -10
@@ 4,6 4,7 @@
module Ast
    ( TVar(..)
    , TPrim(..)
    , TConst
    , Type(..)
    , Scheme(..)
    , scmParams


@@ 23,8 24,6 @@ where

import qualified Data.Set as Set
import Data.Set (Set)
import qualified Data.Map as Map
import Data.Map (Map)
import Data.List
import Data.Bifunctor
import Control.Lens (makeLenses)


@@ 52,10 51,12 @@ data TPrim
    | TBool
    deriving (Show, Eq, Ord)

type TConst = (String, [Type])

data Type
    = TVar TVar
    | TPrim TPrim
    | TConst String [Type]
    | TConst TConst
    | TFun Type Type
    deriving (Show, Eq, Ord)



@@ 96,10 97,10 @@ type Expr = WithPos Expr'

type Def = (Id, (Maybe (WithPos Scheme), Expr))

newtype ConstructorDefs = ConstructorDefs (Map String [Type])
newtype ConstructorDefs = ConstructorDefs [(Id, [Type])]
    deriving (Show, Eq)

data TypeDef = TypeDef String [Id] ConstructorDefs
data TypeDef = TypeDef Id [Id] ConstructorDefs
    deriving (Show, Eq)

data Program = Program [Def] [TypeDef]


@@ 185,7 186,9 @@ prettyProg d (Program defs tdefs) =
prettyTypeDef :: Int -> TypeDef -> String
prettyTypeDef d (TypeDef name params constrs) = concat
    [ "(type "
    , if null params then name else "(" ++ name ++ spcPretty params ++ ")"
    , if null params
        then pretty name
        else "(" ++ pretty name ++ spcPretty params ++ ")"
    , indent (d + 2) ++ pretty' (d + 2) constrs
    , ")"
    ]


@@ 193,11 196,11 @@ prettyTypeDef d (TypeDef name params constrs) = concat
prettyConstructorDefs :: Int -> ConstructorDefs -> String
prettyConstructorDefs d (ConstructorDefs cs) = intercalate
    ("\n" ++ indent d)
    (map prettyConstrDef (Map.toList cs))
    (map prettyConstrDef cs)
  where
    prettyConstrDef = \case
        (c, []) -> c
        (c, ts) -> concat ["(", c, " ", spcPretty ts, ")"]
        (c, []) -> pretty c
        (c, ts) -> concat ["(", pretty c, " ", spcPretty ts, ")"]

prettyExpr' :: Int -> Expr' -> String
prettyExpr' d = \case


@@ 279,7 282,7 @@ prettyType = \case
    Ast.TVar tv -> pretty tv
    Ast.TPrim c -> pretty c
    Ast.TFun a b -> prettyTFun a b
    Ast.TConst c ts -> case ts of
    Ast.TConst (c, ts) -> case ts of
        [] -> c
        ts -> concat ["(", c, " ", spcPretty ts, ")"]


M src/Check.hs => src/Check.hs +95 -64
@@ 4,7 4,7 @@
module Check (typecheck) where

import Control.Lens
    (Lens', (<<+=), assign, makeLenses, over, use, view, views, locally, mapped)
    ((<<+=), assign, makeLenses, over, use, view, views, locally, mapped)
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State.Strict


@@ 29,10 29,10 @@ import AnnotAst

data Env = Env
    { _envDefs :: Map String Scheme
    -- | Maps the name of an algebraic datatype to its definition
    , _envTypeDefs :: Map String Ast.TypeDef
    -- | Maps a constructor to the definition of the type it constructs
    , _envConstructors :: Map String Ast.TypeDef
    -- | Maps a constructor to its variant index in the type definition it
    --   constructs, the signature/left-hand-side of the type definition, and
    --   the types of its parameters
    , _envCtors :: Map String (VariantIx, (String, [TVar]), [Type])
    }
makeLenses ''Env



@@ 60,11 60,7 @@ runInfer' :: Infer a -> Either TypeErr (a, St)
runInfer' = runExcept . flip runStateT initSt . flip runReaderT initEnv

initEnv :: Env
initEnv = Env
    { _envDefs = builtinSchemes
    , _envTypeDefs = Map.empty
    , _envConstructors = Map.empty
    }
initEnv = Env { _envDefs = builtinSchemes, _envCtors = Map.empty }

builtinSchemes :: Map String Scheme
builtinSchemes = Map.fromList


@@ 85,24 81,11 @@ freshVar = fmap show fresh''
fresh'' :: Infer Int
fresh'' = tvCount <<+= 1

withTypes :: [Ast.TypeDef] -> Infer a -> Infer a
withTypes tds =
    let
        tds' = Map.fromList (map (\td@(Ast.TypeDef x _ _) -> (x, td)) tds)
        tdsCs = Map.fromList (concatMap extractCtors tds)
        extractCtors td@(Ast.TypeDef _ _ (Ast.ConstructorDefs cs)) =
            map (, td) (Map.keys cs)
    in augment envTypeDefs tds' . augment envConstructors tdsCs

augment
    :: (MonadReader e m, Ord k) => Lens' e (Map k v) -> Map k v -> m a -> m a
augment l = locally l . Map.union

withLocals :: [(String, Scheme)] -> Infer a -> Infer a
withLocals = withLocals' . Map.fromList

withLocals' :: Map String Scheme -> Infer a -> Infer a
withLocals' = locally envDefs . Map.union
withLocals' = augment envDefs

withLocal :: (String, Scheme) -> Infer a -> Infer a
withLocal b = locally envDefs (uncurry Map.insert b)


@@ 115,12 98,56 @@ inferProgram (Ast.Program defs tdefs) = do
        (throwError MainNotDefined)
        pure
        (lookup "main" (map (first unpos) defs))
    Defs defs' <- withTypes tdefs (inferDefs defs)
    (tdefs', ctors) <- checkTypeDefs tdefs
    Defs defs' <- augment envCtors ctors (inferDefs defs)
    let (Forall _ mainT, main) = fromJust (Map.lookup "main" defs')
    let expectedMainType = TFun (TPrim TUnit) (TPrim TUnit)
    unify (Expected expectedMainType) (Found mainPos mainT)
    let defs'' = Map.delete "main" defs'
    pure (Program main (Defs defs''))
    pure (Program main (Defs defs'') tdefs')

-- | Check a list of type definitions, accumulating them left to right.
--
--   Returns the checked type definitions keyed by type name, together with a
--   map from constructor name to its variant index, the left-hand side of the
--   type definition it belongs to, and its parameter types.
--
--   Throws `ConflictingTypeDef` if a type name is defined more than once, and
--   `ConflictingCtorDef` if a constructor name is defined more than once. In
--   both cases the reported identifier is the later, redefining occurrence,
--   which matches how `checkTypeDef` reports duplicates within a single type.
checkTypeDefs
    :: [Ast.TypeDef]
    -> Infer
           ( Map String ([TVar], [[Type]])
           , Map String (VariantIx, (String, [TVar]), [Type])
           )
checkTypeDefs tdefs = do
    (tdefs', ctors) <- foldM checkNext (Map.empty, Map.empty) tdefs
    -- The constructor identifiers were only kept around for error reporting.
    pure (tdefs', fmap snd ctors)
  where
    checkNext (tdsAcc, csAcc) td@(Ast.TypeDef x _ _) = do
        when (Map.member (idstr x) tdsAcc) (throwError (ConflictingTypeDef x))
        ((tx, tdef), cs) <- checkTypeDef td
        -- `Map.intersection` keeps the values of its first argument, so put
        -- the new constructors first to report the redefinition rather than
        -- the original definition.
        case listToMaybe (Map.elems (Map.intersection cs csAcc)) of
            Just (cId, _) -> throwError (ConflictingCtorDef cId)
            Nothing -> pure (Map.insert tx tdef tdsAcc, Map.union cs csAcc)

-- | Check a single type definition.
--
--   Returns the name of the type paired with its type parameters and the
--   member types of each variant, along with a map of its constructors. Each
--   constructor is numbered by its position in the definition, starting at 0.
--   Throws `ConflictingCtorDef` if the same constructor name appears twice.
checkTypeDef
    :: Ast.TypeDef
    -> Infer
           ( (String, ([TVar], [[Type]]))
           , Map String (Id, (VariantIx, (String, [TVar]), [Type]))
           )
checkTypeDef (Ast.TypeDef (WithPos _ x) ps (Ast.ConstructorDefs cs)) = do
    ctors <- foldM insertCtor Map.empty (zip [0 ..] cs)
    pure ((x, (tvars, map snd cs)), ctors)
  where
    tvars = map TVExplicit ps
    insertCtor acc (i, (cx, cps))
        | Map.member (idstr cx) acc = throwError (ConflictingCtorDef cx)
        | otherwise = pure
            (Map.insert (idstr cx) (cx, (i, (x, tvars), cps)) acc)
--
-- withTypes tds =
--     let
--         tds' = Map.fromList (map (\td@(Ast.TypeDef x _ _) -> (x, td)) tds)
--         tdsCs = Map.fromList (concatMap extractCtors tds)
--         extractCtors td@(Ast.TypeDef _ _ (Ast.ConstructorDefs cs)) =
--             map (second (const td)) cs
--     in augment envTypeDefs tds' . augment envCtors tdsCs

inferDefs :: [Ast.Def] -> Infer Defs
inferDefs defs = do


@@ 149,17 176,16 @@ inferDefsComponents = \case
        let mayscms' = map (fmap unpos) mayscms
        let names = map idstr idents
        ts <- replicateM (length names) fresh
        let
            scms = map
        let scms = map
                (\(mayscm, t) -> fromMaybe (Forall Set.empty t) mayscm)
                (zip mayscms' ts)
        bodies' <-
            withLocals (zip names scms)
            $ forM (zip bodies scms)
            $ \(body, Forall _ t1) -> do
                  (t2, body') <- infer body
                  unify (Expected t1) (Found (getPos body) t2)
                  pure body'
                (t2, body') <- infer body
                unify (Expected t1) (Found (getPos body) t2)
                pure body'
        generalizeds <- mapM generalize ts
        let scms' = zipWith fromMaybe generalizeds mayscms'
        let annotDefs = Map.fromList (zip names (zip scms' bodies'))


@@ 255,16 281,17 @@ inferPat = \case
inferPatConstruction
    :: SrcPos -> Id -> [Ast.Pat] -> Infer (Type, Pat, Map Id Scheme)
inferPatConstruction pos c cArgs = do
    ctorOfTypeDef@(cParams, _) <- lookupEnvConstructor c
    (variantIx, tdefLhs, cParams) <- lookupEnvConstructor c
    let arity = length cParams
    let nArgs = length cArgs
    unless (arity == nArgs) (throwError (CtorArityMismatch pos c arity nArgs))
    (cParams', t) <- instantiateConstructorOfTypeDef ctorOfTypeDef
    (tdefInst, cParams') <- instantiateConstructorOfTypeDef tdefLhs cParams
    let t = TConst tdefInst
    (cArgTs, cArgs', cArgsVars) <- fmap unzip3 (mapM inferPat cArgs)
    cArgsVars' <- nonconflictingPatVarDefs cArgsVars
    forM_ (zip3 cParams' cArgTs cArgs) $ \(cParamT, cArgT, cArg) ->
        unify (Expected cParamT) (Found (getPos cArg) cArgT)
    pure (t, PConstruction (idstr c) cArgs', cArgsVars')
    pure (t, PConstruction variantIx cArgs', cArgsVars')

nonconflictingPatVarDefs :: [Map Id Scheme] -> Infer (Map Id Scheme)
nonconflictingPatVarDefs = flip foldM Map.empty $ \acc ks ->


@@ 274,33 301,36 @@ nonconflictingPatVarDefs = flip foldM Map.empty $ \acc ks ->

inferExprConstructor :: Id -> Infer (Type, Expr)
inferExprConstructor c = do
    ctorOfTypeDef <- lookupEnvConstructor c
    (cParams', t) <- instantiateConstructorOfTypeDef ctorOfTypeDef
    pure (foldr TFun t cParams', Constructor (idstr c))
    (variantIx, tdefLhs, cParams) <- lookupEnvConstructor c
    (tdefInst, cParams') <- instantiateConstructorOfTypeDef tdefLhs cParams
    pure
        ( foldr TFun (TConst tdefInst) cParams'
        , Ctor (variantIx, tdefInst, cParams')
        )

instantiateConstructorOfTypeDef
    :: ([Type], (String, [Id])) -> Infer ([Type], Type)
instantiateConstructorOfTypeDef (cParams, (tName, tParams)) = do
    :: (String, [TVar]) -> [Type] -> Infer (TConst, [Type])
instantiateConstructorOfTypeDef (tName, tParams) cParams = do
    tVars <- mapM (const fresh) tParams
    let tParams' = map TVExplicit tParams
    let cParams' = map (subst (Map.fromList (zip tParams' tVars))) cParams
    let t = TConst tName tVars
    pure (cParams', t)
    let cParams' = map (subst (Map.fromList (zip tParams tVars))) cParams
    pure ((tName, tVars), cParams')

lookupEnvConstructor :: Id -> Infer ([Type], (String, [Id]))
lookupEnvConstructor :: Id -> Infer (VariantIx, (String, [TVar]), [Type])
lookupEnvConstructor (WithPos pos cx) =
    views envConstructors (Map.lookup cx) >>= \case
        Just (Ast.TypeDef tx tps cs) ->
            case lookupConstructorParamTypes cx cs of
                Just cps -> pure (cps, (tx, tps))
                Nothing ->
                    ice
                        $ ("lookup failed for ctor `" ++ cx)
                        ++ ("` in type `" ++ tx ++ "`")
        Nothing -> throwError (UndefCtor pos cx)

lookupConstructorParamTypes :: String -> Ast.ConstructorDefs -> Maybe [Type]
lookupConstructorParamTypes cx (Ast.ConstructorDefs cs) = Map.lookup cx cs
    views envCtors (Map.lookup cx)
        >>= maybe (throwError (UndefCtor pos cx)) pure
    -- views envCtors (Map.lookup cx) >>= \case
    --     Just (Ast.TypeDef tx tps cs) ->
    --         case lookupConstructorParamTypes cx cs of
    --             Just cps -> pure (cps, (tx, tps))
    --             Nothing ->
    --                 ice
    --                     $ ("lookup failed for ctor `" ++ cx)
    --                     ++ ("` in type `" ++ tx ++ "`")
    --     Nothing -> throwError (UndefCtor pos cx)

-- lookupConstructorParamTypes :: String -> Ast.ConstructorDefs -> Maybe [Type]
-- lookupConstructorParamTypes cx (Ast.ConstructorDefs cs) = lookup cx cs

litType :: Const -> Type
litType = \case


@@ 319,8 349,8 @@ lookupEnv (WithPos pos x) = views envDefs (Map.lookup x) >>= \case
-- Substitution
--------------------------------------------------------------------------------
substProgram :: Subst -> Program -> Program
substProgram s (Program main (Defs defs)) =
    Program (substExpr s main) (Defs (fmap (substDef s) defs))
substProgram s (Program main (Defs defs) tdefs) =
    Program (substExpr s main) (Defs (fmap (substDef s) defs)) tdefs

substDef :: Subst -> (Scheme, Expr) -> (Scheme, Expr)
substDef s = bimap id (substExpr s)


@@ 337,7 367,7 @@ substExpr s = \case
    Match e cs -> Match
        (substExpr s e)
        (map (\(p, b) -> (substPat s p, substExpr s b)) cs)
    Constructor c -> Constructor c
    Ctor c -> Ctor c

substPat :: Subst -> Pat -> Pat
substPat s = \case


@@ 349,7 379,7 @@ subst s t = case t of
    TVar tv -> fromMaybe t (Map.lookup tv s)
    TPrim _ -> t
    TFun a b -> TFun (subst s a) (subst s b)
    TConst c ts -> TConst c (map (subst s) ts)
    TConst (c, ts) -> TConst (c, (map (subst s) ts))

substEnv :: Subst -> Env -> Env
substEnv s = over (envDefs . mapped . scmBody) (subst s)


@@ 381,9 411,10 @@ data UnifyErr'' = InfiniteType'' TVar Type | UnificationFailed'' Type Type
unify'' :: Type -> Type -> Except UnifyErr'' Subst
unify'' = curry $ \case
    (TPrim a, TPrim b) | a == b -> pure Map.empty
    (TConst c0 ts0, TConst c1 ts1) | c0 == c1 -> if length ts0 /= length ts1
        then ice "lengths of TConst params differ in unify"
        else unifys ts0 ts1
    (TConst (c0, ts0), TConst (c1, ts1)) | c0 == c1 ->
        if length ts0 /= length ts1
            then ice "lengths of TConst params differ in unify"
            else unifys ts0 ts1
    (TVar a, TVar b) | a == b -> pure Map.empty
    (TVar a, t) | occursIn a t -> throwError (InfiniteType'' a t)
    -- Do not allow "override" of explicit (user given) type variables.


@@ 428,7 459,7 @@ ftv = \case
    TVar tv -> Set.singleton tv
    TPrim _ -> Set.empty
    TFun t1 t2 -> Set.union (ftv t1) (ftv t2)
    TConst _ ts -> Set.unions (map ftv ts)
    TConst (_, ts) -> Set.unions (map ftv ts)

ftvEnv :: Env -> Set TVar
ftvEnv = Set.unions . map (ftvScheme . snd) . Map.toList . view envDefs

M src/Codegen.hs => src/Codegen.hs +33 -20
@@ 54,10 54,9 @@ import qualified MonoAst
import MonoAst hiding (Type, Const)
import qualified SizeOf


-- | An instruction that returns a value. The name refers to the fact that a
-- mathematical function always returns a value, but an imperative procedure may
-- only produce side effects.
--   mathematical function always returns a value, but an imperative procedure
--   may only produce side effects.
data FunInstruction = WithRetType Instruction Type

-- TODO: Either treat globals and locals separately - Globals behind pointers,


@@ 82,7 81,8 @@ type Gen' = StateT St (ReaderT Env IO)
data Out = Out
    { _outBlocks :: [BasicBlock]
    , _outStrings :: [(Name, String)]
    , _outFuncs :: [(Name, [TypedVar], TypedVar, Expr)]}
    , _outFuncs :: [(Name, [TypedVar], TypedVar, Expr)]
    }
makeLenses ''Out

type Gen = WriterT Out Gen'


@@ 103,12 103,14 @@ instance Pretty Module where


codegen :: Context -> FilePath -> Program -> IO Module
codegen context moduleFilePath (Program main (Defs defs)) = do
codegen context moduleFilePath (Program main (Defs defs) tdefs) = do
    let defs' = (TypedVar "main" mainType, main) : Map.toList defs
        genGlobDefs = withGlobDefSigs
            defs'
            (liftA2 (:) genOuterMain (fmap join (mapM genGlobDef defs')))
    globDefs <- runGen' context genGlobDefs
    globDefs <- runGen'
        context
        (liftA2 (++) (mapM genTypeDef tdefs) genGlobDefs)
    pure Module
        { moduleName = fromString ((takeBaseName moduleFilePath))
        , moduleSourceFileName = fromString moduleFilePath


@@ 129,6 131,14 @@ initSt = St
    , _registerCount = 0
    }

-- | Generate the LLVM type definition for one instance of a type definition.
--
--   Every variant is turned into a struct of its member types. The named type
--   is then defined as the variant struct that is greatest when the
--   @(size, struct type)@ pairs are compared, i.e. a variant of maximal size
--   as reported by `sizeof`.
genTypeDef :: (TConst, [[MonoAst.Type]]) -> Gen' Definition
genTypeDef (lhs, variants) = do
    -- Named by the same mangled string that `toLlvmType` uses for a `TConst`
    let name = mkName (mangleTConst lhs)
    let ts = map (typeStruct . map toLlvmType) variants
    sizedTs <- mapM (\t -> fmap (, t) (sizeof t)) ts
    -- NOTE(review): `maximum` is partial, so this assumes `variants` is
    -- non-empty -- TODO confirm that every instantiated type has a variant.
    -- Ties on size are broken by the `Ord` instance of the struct type.
    let (_, tmax) = maximum sizedTs
    pure (TypeDefinition name (Just tmax))

genBuiltins :: [Definition]
genBuiltins = map
    (GlobalDefinition . ($ []))


@@ 218,7 228,7 @@ genExpr = \case
    Fun p b -> genLambda p b
    Let ds b -> genLet ds b
    Match e cs -> genMatch e cs
    Constructor _ -> nyi "genExpr Constructor"
    Ctor _ -> nyi "genExpr Constructor"

-- | Convert to the LLVM representation of a type in an expression-context.
toLlvmType :: MonoAst.Type -> Type


@@ 234,7 244,7 @@ toLlvmType = \case
        [ LLType.ptr typeUnit
        , LLType.ptr (typeClosureFun (toLlvmType a) (toLlvmType r))
        ]
    t@(TConst _ _) -> typeNamed (mangleType t)
    TConst t -> typeNamed (mangleTConst t)

genConst :: MonoAst.Const -> Gen LLConst.Constant
genConst = \case


@@ 328,7 338,7 @@ genMatchPattern _nextCaseL m = \case
        pure [(var, n)]

withDefSigs :: [(TypedVar, Name)] -> Gen a -> Gen a
withDefSigs = locally localEnv . Map.union . Map.fromList . map
withDefSigs = augment localEnv . Map.fromList . map
    (\(v@(TypedVar _ t), n') -> (v, LocalReference (toLlvmType t) n'))

-- TODO: Change global defs to a new type that can be generated by llvm.  As it


@@ 374,7 384,7 @@ genStruct xs = do
genBoxGeneric :: Operand -> Gen Operand
genBoxGeneric x = do
    let t = typeOf x
    ptrGeneric <- genMalloc =<< sizeof'' t
    ptrGeneric <- genMalloc =<< genSizeof t
    ptr <- emitAnon (bitcast ptrGeneric (LLType.ptr t))
    emit (store x ptr)
    pure ptrGeneric


@@ 451,19 461,19 @@ simpleGlobVar name t init = GlobalVariable
parameter :: Name -> Type -> LLGlob.Parameter
parameter p pt = LLGlob.Parameter pt p []

sizeof'' :: Type -> Gen Operand
sizeof'' = fmap ConstantOperand . sizeof'
genSizeof :: Type -> Gen Operand
genSizeof = fmap (ConstantOperand . litI64 . fromIntegral) . sizeof'

sizeof' :: Type -> Gen LLConst.Constant
sizeof' = fmap (litI64 . fromIntegral) . sizeof
sizeof' :: Type -> Gen Word64
sizeof' = lift . sizeof

sizeof :: Type -> Gen Word64
sizeof :: Type -> Gen' Word64
sizeof t = do
    c <- view ctx
    liftIO (SizeOf.sizeof c t)

withGlobDefSigs :: MonadReader Env m => [(TypedVar, Expr)] -> m a -> m a
withGlobDefSigs = locally globalEnv . Map.union . Map.fromList . map
withGlobDefSigs = augment globalEnv . Map.fromList . map
    (\(v@(TypedVar _ t), _) ->
        ( v
        , ConstantOperand


@@ 664,7 674,7 @@ getMembers = \case
    t -> ice $ "Tried to get member types of non-struct type " ++ pretty t

getIndexed :: Type -> [Word32] -> Type
getIndexed = foldl (\t i -> getMembers t !! (fromIntegral i))
getIndexed = foldl' (\t i -> getMembers t !! (fromIntegral i))

mangleName :: TypedVar -> Name
mangleName (TypedVar x t) = mkName (x ++ ":" ++ mangleType t)


@@ 672,9 682,12 @@ mangleName (TypedVar x t) = mkName (x ++ ":" ++ mangleType t)
mangleType :: MonoAst.Type -> String
mangleType = \case
    TPrim c -> pretty c
    TFun p r -> mangleType (TConst "->" [p, r])
    TConst c ts ->
        concat ["(", c, ",", intercalate "," (map mangleType ts), ")"]
    TFun p r -> mangleTConst ("->", [p, r])
    TConst tc -> mangleTConst tc

-- | Mangle a type constant to a string of the form @(name,arg1,...,argN)@,
--   where each argument is mangled with `mangleType`.
mangleTConst :: TConst -> String
mangleTConst (c, ts) = "(" ++ c ++ "," ++ mangledArgs ++ ")"
    where mangledArgs = intercalate "," (map mangleType ts)

unName :: Name -> ShortByteString
unName = \case

M src/Interp.hs => src/Interp.hs +4 -4
@@ 16,7 16,7 @@ import MonoAst
data Val
    = VConst Const
    | VFun (Val -> IO Val)
    | VConstruction String
    | VConstruction VariantIx
                    [Val] -- ^ Arguments are in reverse order--last arg first

type Env = Map TypedVar Val


@@ 43,7 43,7 @@ plus :: Val -> Val -> Val
plus a b = VConst (Int (unwrapInt a + unwrapInt b))

evalProgram :: Program -> Eval ()
evalProgram (Program main defs) = do
evalProgram (Program main defs _) = do
    f <- evalLet defs main
    fmap unwrapUnit (unwrapFun' f (VConst Unit))



@@ 65,7 65,7 @@ eval = \case
        pure (VFun (\v -> runEval (withLocals env (withLocal p v (eval b)))))
    Let defs body -> evalLet defs body
    Match e cs -> eval e >>= flip evalCases cs
    Constructor c -> pure (VConstruction c [])
    Ctor (i, _, _) -> pure (VConstruction i [])

evalApp :: Expr -> Expr -> Eval Val
evalApp ef ea = eval ef >>= \case


@@ 138,4 138,4 @@ showVariant = \case
        Bool _ -> "bool"
        Char _ -> "character"
    VFun _ -> "function"
    VConstruction c _ -> "construction of " ++ c
    VConstruction c _ -> "construction of variant " ++ show c

M src/Misc.hs => src/Misc.hs +10 -1
@@ 1,4 1,4 @@
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, LambdaCase, RankNTypes #-}

module Misc
    ( ice


@@ 12,10 12,15 @@ module Misc
    , showChar''
    , showChar'
    , both
    , augment
    )
where

import Data.List (intercalate)
import qualified Data.Map as Map
import Data.Map (Map)
import Control.Monad.Reader
import Control.Lens (Lens', locally)

ice :: String -> a
ice = error . ("Internal Compiler Error: " ++)


@@ 70,3 75,7 @@ showChar' c = "'" ++ showChar'' c ++ "'"

both :: (a -> b) -> (a, a) -> (b, b)
both f (a0, a1) = (f a0, f a1)

-- | Run an action in a reader environment where the map focused on by the
--   given lens has been unioned with the given map. On duplicate keys, the
--   entries of the given map are preferred, as `Map.union` is left-biased.
augment
    :: (MonadReader e m, Ord k) => Lens' e (Map k v) -> Map k v -> m a -> m a
augment l new = locally l (Map.union new)

M src/Mono.hs => src/Mono.hs +70 -35
@@ 5,9 5,10 @@
module Mono (monomorphize) where

import Control.Applicative (liftA2, liftA3)
import Control.Lens (makeLenses, over, views)
import Control.Lens (makeLenses, views, use, uses, modifying)
import Control.Monad.Reader
import Control.Monad.State
import Data.Functor
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Maybe


@@ 23,17 24,25 @@ data Env = Env
    { _defs :: Map String (Scheme, An.Expr)
    , _tvBinds :: Map TVar Type
    }

makeLenses ''Env

type Insts = Map String (Map Type Expr)
data Insts = Insts
    { _defInsts :: Map String (Map Type Expr)
    , _tdefInsts :: Set TConst
    }
makeLenses ''Insts

-- | The monomorphization monad
type Mono = ReaderT Env (State Insts)
type Mono = StateT Insts (Reader Env)

monomorphize :: An.Program -> Program
monomorphize (An.Program main ds) = (uncurry (flip Program))
    (evalState (runReaderT (monoLet ds main) initEnv) Map.empty)
monomorphize (An.Program main defs tdefs) =
    let
        initInsts = Insts Map.empty Set.empty
        ((defs', main'), Insts _ tdefInsts') =
            runReader (runStateT (monoLet defs main) initInsts) initEnv
        tdefs' = instTypeDefs tdefs tdefInsts'
    in Program main' defs' tdefs'

initEnv :: Env
initEnv = Env { _defs = Map.empty, _tvBinds = Map.empty }


@@ 43,36 52,35 @@ mono = \case
    An.Lit c -> pure (Lit c)
    An.Var (An.TypedVar x t) -> do
        t' <- monotype t
        addInst x t'
        addDefInst x t'
        pure (Var (TypedVar x t'))
    An.App f a -> liftA2 App (mono f) (mono a)
    An.If p c a -> liftA3 If (mono p) (mono c) (mono a)
    An.Fun p b -> monoFun p b
    An.Let ds b -> fmap (uncurry Let) (monoLet ds b)
    An.Match e cs -> monoMatch e cs
    An.Constructor c -> pure (Constructor c)
    An.Ctor c -> monoCtor c

monoFun :: (String, An.Type) -> (An.Expr, An.Type) -> Mono Expr
monoFun (p, tp) (b, bt) = do
    parentInst <- gets (Map.lookup p)
    modify (Map.delete p)
    parentInst <- uses defInsts (Map.lookup p)
    modifying defInsts (Map.delete p)
    tp' <- monotype tp
    b' <- mono b
    bt' <- monotype bt
    maybe (pure ()) (modify . Map.insert p) parentInst
    maybe (pure ()) (modifying defInsts . Map.insert p) parentInst
    pure (Fun (TypedVar p tp') (b', bt'))

monoLet :: An.Defs -> An.Expr -> Mono (Defs, Expr)
monoLet (An.Defs ds) body = do
    let ks = Map.keys ds
    parentInsts <- gets (lookups ks)
    parentInsts <- uses defInsts (lookups ks)
    let newEmptyInsts = (fmap (const Map.empty) ds)
    modify (Map.union newEmptyInsts)
    body' <- local (over defs (Map.union ds)) (mono body)
    dsInsts <- gets (lookups ks)
    modify (Map.union (Map.fromList parentInsts))
    let
        ds' = Map.fromList $ do
    modifying defInsts (Map.union newEmptyInsts)
    body' <- augment defs ds (mono body)
    dsInsts <- uses defInsts (lookups ks)
    modifying defInsts (Map.union (Map.fromList parentInsts))
    let ds' = Map.fromList $ do
            (name, dInsts) <- dsInsts
            (t, body) <- Map.toList dInsts
            pure (TypedVar name t, body)


@@ 88,11 96,10 @@ monoCase :: (An.Pat, An.Expr) -> Mono (Pat, Expr)
monoCase (p, e) = do
    (p', pvs) <- monoPat p
    let pvs' = Set.toList pvs
    -- let pvs = patternBoundVars p :: Set An.TypedVar
    parentInsts <- gets (lookups pvs')
    modify (deletes pvs')
    parentInsts <- uses defInsts (lookups pvs')
    modifying defInsts (deletes pvs')
    e' <- mono e
    modify (Map.union (Map.fromList parentInsts))
    modifying defInsts (Map.union (Map.fromList parentInsts))
    pure (p', e')

monoPat :: An.Pat -> Mono (Pat, Set String)


@@ 103,19 110,24 @@ monoPat = \case
    An.PVar (An.TypedVar x t) ->
        fmap (\t' -> (PVar (TypedVar x t'), Set.singleton x)) (monotype t)

addInst :: String -> Type -> Mono ()
addInst x t1 = do
    insts <- get
    case Map.lookup x insts of
-- | Monomorphize a constructor expression, and record the instance of the
--   type it constructs in `tdefInsts`.
monoCtor :: An.Ctor -> Mono Expr
monoCtor (variantIx, (tName, tArgs), paramTs) = do
    tArgs' <- traverse monotype tArgs
    paramTs' <- traverse monotype paramTs
    let inst = (tName, tArgs')
    modifying tdefInsts (Set.insert inst)
    pure (Ctor (variantIx, inst, paramTs'))

addDefInst :: String -> Type -> Mono ()
addDefInst x t1 = do
    use defInsts <&> Map.lookup x >>= \case
        -- If x is not in insts, it's a function parameter. Ignore.
        Nothing -> pure ()
        Just xInsts -> unless (Map.member t1 xInsts) $ do
            (Forall _ t2, body) <- views
                defs
                (lookup' (ice (x ++ " not in defs")) x)
            body' <- local
                (over tvBinds (Map.union (bindTvs t2 t1)))
                (mono body)
            body' <- augment tvBinds (bindTvs t2 t1) (mono body)
            insertInst x t1 body'

bindTvs :: An.Type -> Type -> Map TVar Type


@@ 123,21 135,44 @@ bindTvs a b = case (a, b) of
    (An.TVar v, t) -> Map.singleton v t
    (An.TFun p0 r0, TFun p1 r1) -> Map.union (bindTvs p0 p1) (bindTvs r0 r1)
    (An.TPrim _, TPrim _) -> Map.empty
    (An.TConst _ ts0, TConst _ ts1) -> Map.unions (zipWith bindTvs ts0 ts1)
    (An.TConst (_, ts0), TConst (_, ts1)) ->
        Map.unions (zipWith bindTvs ts0 ts1)
    (An.TPrim _, _) -> err
    (An.TFun _ _, _) -> err
    (An.TConst _ _, _) -> err
    (An.TConst _, _) -> err
    where err = ice $ "bindTvs: " ++ show a ++ ", " ++ show b

monotype :: An.Type -> Mono Type
monotype = \case
monotype = lift . monotype'

monotype' :: An.Type -> Reader Env Type
monotype' = \case
    An.TVar v -> views tvBinds (lookup' (ice (show v ++ " not in tvBinds")) v)
    An.TPrim c -> pure (TPrim c)
    An.TFun a b -> liftA2 TFun (monotype a) (monotype b)
    An.TConst c ts -> fmap (TConst c) (mapM monotype ts)
    An.TFun a b -> liftA2 TFun (monotype' a) (monotype' b)
    An.TConst (c, ts) -> fmap (curry TConst c) (mapM monotype' ts)

insertInst :: String -> Type -> Expr -> Mono ()
insertInst x t b = modify (Map.adjust (Map.insert t b) x)
insertInst x t b = modifying defInsts (Map.adjust (Map.insert t b) x)

-- Anot: [(String, ([TVar], [[Type]]))]
-- Mono: [(TConst, [[Type]])]
--
-- Env
--    { _defs :: Map String (Scheme, An.Expr)
--    , _tvBinds :: Map TVar Type
--    }
-- | Instantiate the type definitions for each of the given type instances.
--
--   For every instance, the definition is looked up by name, and the member
--   types of its variants are monomorphized with the type parameters bound to
--   the type arguments of the instance.
instTypeDefs :: An.TypeDefs -> Set TConst -> TypeDefs
instTypeDefs tdefs insts = [ instantiate inst | inst <- Set.toList insts ]
  where
    instantiate inst@(x, ts) =
        case lookup' (ice "in instTypeDefs") x tdefs of
            (tvs, variants) ->
                let env = Env Map.empty (Map.fromList (zip tvs ts))
                in (inst, runReader (mapM (mapM monotype') variants) env)

lookup' :: Ord k => v -> k -> Map k v -> v
lookup' = Map.findWithDefault

M src/MonoAst.hs => src/MonoAst.hs +25 -10
@@ 5,12 5,16 @@

module MonoAst
    ( TPrim(..)
    , TConst
    , Type(..)
    , TypedVar(..)
    , VariantIx
    , Pat(..)
    , Ctor
    , Const(..)
    , Expr(..)
    , Defs(..)
    , TypeDefs
    , Program(..)
    , mainType
    )


@@ 20,24 24,30 @@ import qualified Data.Map as Map
import Data.Map (Map)
import qualified Data.Set as Set
import Data.Set (Set)
import AnnotAst (VariantIx)

import FreeVars
import Ast (Const(..), TPrim(..))

type TConst = (String, [Type])

data Type
    = TPrim TPrim
    | TFun Type Type
    | TConst String [Type]
    | TConst TConst
    deriving (Show, Eq, Ord)

data TypedVar = TypedVar String Type
    deriving (Show, Eq, Ord)

data Pat
    = PConstruction String [Pat]
    = PConstruction VariantIx [Pat]
    | PVar TypedVar
    deriving (Show, Eq)

-- | (Variant index, constructed type, inner types of this variant)
type Ctor = (VariantIx, TConst, [Type])

data Expr
    = Lit Const
    | Var TypedVar


@@ 46,21 56,26 @@ data Expr
    | Fun TypedVar (Expr, Type)
    | Let Defs Expr
    | Match Expr [(Pat, Expr)]
    | Constructor String
    | Ctor Ctor
    deriving (Show)

newtype Defs = Defs (Map TypedVar Expr)
    deriving (Show)

data Program = Program Expr Defs
-- | The member types of a single variant of a type definition
type Variant = [Type]
-- | Instances of type definitions, each paired with its list of variants
type TypeDefs = [(TConst, [Variant])]

data Program = Program Expr Defs TypeDefs
    deriving (Show)

mainType :: Type
mainType = TFun (TPrim TUnit) (TPrim TUnit)

instance FreeVars Expr TypedVar where
    freeVars = fvExpr

instance Pattern Pat TypedVar where
    patternBoundVars = bvPat


fvExpr :: Expr -> Set TypedVar
fvExpr = \case
    Lit _ -> Set.empty


@@ 70,12 85,12 @@ fvExpr = \case
    Fun p (b, _) -> fvFun p b
    Let (Defs bs) e -> fvLet (Map.keysSet bs, Map.elems bs) e
    Match e cs -> fvMatch e cs
    Constructor _ -> Set.empty

instance Pattern Pat TypedVar where
    patternBoundVars = bvPat
    Ctor _ -> Set.empty

bvPat :: Pat -> Set TypedVar
bvPat = \case
    PConstruction _ ps -> Set.unions (map bvPat ps)
    PVar x -> Set.singleton x

mainType :: Type
mainType = TFun (TPrim TUnit) (TPrim TUnit)

M src/Parse.hs => src/Parse.hs +15 -13
@@ 18,6 18,7 @@ module Parse
    , var
    , eConstructor
    , ns_expr
    , ns_big
    )
where



@@ 33,10 34,10 @@ import Text.Megaparsec.Char hiding (space, space1)
import qualified Text.Megaparsec.Char as Char
import qualified Text.Megaparsec.Char.Lexer as Lexer
import qualified Data.Set as Set
import qualified Data.Map as Map
import Data.Either.Combinators
import Data.Void
import Data.Composition
import Data.List

import Misc
import SrcPos


@@ 72,11 73,11 @@ toplevel = do
typedef :: Parser TypeDef
typedef = do
    _ <- reserved "type"
    let onlyName = fmap (, []) big
    let nameAndSome = parens . liftA2 (,) big . some
    let onlyName = fmap (, []) big'
    let nameAndSome = parens . liftA2 (,) big' . some
    (name, params) <- onlyName <|> nameAndSome small'
    constrs <- many (onlyName <|> nameAndSome type_)
    pure (TypeDef name params (ConstructorDefs (Map.fromList constrs)))
    pure (TypeDef name params (ConstructorDefs constrs))

def :: SrcPos -> Parser Def
def topPos = defUntyped topPos <|> defTyped topPos


@@ 118,8 119,7 @@ ns_expr = withPos
        a <- eitherP
            (try (Lexer.decimal <* notFollowedBy (char '.')))
            Lexer.float
        let
            e = either
        let e = either
                (\n -> Int (if neg then -n else n))
                (\x -> Double (if neg then -x else x))
                a


@@ 127,11 127,13 @@ ns_expr = withPos
    charLit = fmap
        (Lit . Char)
        (between (char '\'') (char '\'') Lexer.charLiteral)
    str = fmap (Lit . Str) (char '"' >> manyTill Lexer.charLiteral (char '"'))
    str =
        fmap (Lit . Str) $ char '"' >> manyTill Lexer.charLiteral (char '"')
    bool = do
        b <- (ns_reserved "true" $> True) <|> (ns_reserved "false" $> False)
        pure (Lit (Bool b))
    pexpr = ns_parens (choice [funMatch, match, if', fun, let', typeAscr, app])
    pexpr = ns_parens
        (choice [funMatch, match, if', fun, let', typeAscr, app])

eConstructor :: Parser Expr'
eConstructor = fmap Constructor ns_big'


@@ 173,7 175,7 @@ app :: Parser Expr'
app = do
    rator <- expr
    rands <- some expr
    pure (unpos (foldl (WithPos (getPos rator) .* App) rator rands))
    pure (unpos (foldl' (WithPos (getPos rator) .* App) rator rands))

if' :: Parser Expr'
if' = do


@@ 203,8 205,8 @@ let' = do
binding :: Parser Def
binding = parens (bindingTyped <|> bindingUntyped)
  where
    bindingTyped =
        reserved ":" *> liftA2 (,) small' (liftA2 (,) (fmap Just scheme) expr)
    bindingTyped = reserved ":"
        *> liftA2 (,) small' (liftA2 (,) (fmap Just scheme) expr)
    bindingUntyped = liftA2 (,) small' (fmap (Nothing, ) expr)

typeAscr :: Parser Expr'


@@ 229,7 231,7 @@ nonptype = andSkipSpaceAfter ns_nonptype

ns_nonptype :: Parser Type
ns_nonptype = choice
    [fmap TPrim ns_tprim, fmap TVar ns_tvar, fmap (flip TConst []) ns_big]
    [fmap TPrim ns_tprim, fmap TVar ns_tvar, fmap (TConst . (, [])) ns_big]

ptype :: Parser Type
ptype = parens ptype'


@@ 238,7 240,7 @@ ptype' :: Parser Type
ptype' = tfun <|> tapp

tapp :: Parser Type
tapp = liftA2 TConst big (some type_)
tapp = liftA2 (TConst .* (,)) big (some type_)

tfun :: Parser Type
tfun = do

M src/TypeErr.hs => src/TypeErr.hs +7 -1
@@ 20,6 20,8 @@ data TypeErr
    | UndefVar SrcPos String
    | InfType SrcPos TVar Type
    | UnificationFailed SrcPos Type Type Type Type
    | ConflictingTypeDef Id
    | ConflictingCtorDef Id

type Message = String



@@ 38,7 40,7 @@ prettyErr = \case
    ConflictingPatVarDefs p v ->
        posd p var
            $ "Conflicting definitions for variable `"
            ++ pretty v
            ++ v
            ++ "` in pattern."
    UndefCtor p c ->
        posd p eConstructor $ "Undefined constructor `" ++ c ++ "`"


@@ 50,6 52,10 @@ prettyErr = \case
            $ ("Couldn't match type " ++ pretty t'2 ++ " with " ++ pretty t'1)
            ++ (".\nExpected type: " ++ pretty t1)
            ++ (".\nFound type: " ++ pretty t2 ++ ".")
    ConflictingTypeDef (WithPos p x) ->
        posd p ns_big $ "Conflicting definitions for type `" ++ x ++ "`."
    ConflictingCtorDef (WithPos p x) ->
        posd p ns_big $ "Conflicting definitions for constructor `" ++ x ++ "`."

posd :: SrcPos -> Parser a -> Message -> Source -> String
posd (SrcPos pos@(SourcePos _ lineN colN)) parser msg src =