~jojo/Carth

a012826829b8f8aa4129ec1b9f0d5babcfe45dff — JoJo 1 year, 10 months ago b7f7229
Check user-written types in Infer

Basically, do `checkType` for user-written types. It detects
references to undefined types. Also, separete AnnotAst.Type from
Ast.Type to help enforce that this is done. This whole ordeal required
some refactoring.
12 files changed, 317 insertions(+), 268 deletions(-)

M app/Main.hs
M src/AnnotAst.hs
M src/Ast.hs
M src/Check.hs
M src/Codegen.hs
M src/Infer.hs
M src/Match.hs
M src/Misc.hs
M src/Parse.hs
R src/{PrettyAst.hs => Pretty.hs}
M src/SrcPos.hs
M src/TypeErr.hs
M app/Main.hs => app/Main.hs +1 -0
@@ 6,6 6,7 @@ import System.Environment
import Control.Monad

import Misc
import Pretty
import qualified TypeErr
import qualified Ast
import qualified DesugaredAst

M src/AnnotAst.hs => src/AnnotAst.hs +27 -4
@@ 1,4 1,4 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase, TemplateHaskell #-}

-- | Type annotated AST as a result of typechecking
module AnnotAst


@@ 8,6 8,8 @@ module AnnotAst
    , TConst
    , Type(..)
    , Scheme(..)
    , scmParams
    , scmBody
    , Id
    , TypedVar(..)
    , Const(..)


@@ 28,13 30,30 @@ module AnnotAst
    )
where

import Data.Set (Set)
import Data.Map.Strict (Map)
import Lens.Micro.Platform (makeLenses)

import Ast
    (TVar(..), TPrim(..), TConst, Type(..), Scheme(..), Const(..), startType)
import Ast (TVar(..), TPrim(..), Const(..))
import SrcPos


type TConst = (String, [Type])

data Type
    = TVar TVar
    | TPrim TPrim
    | TConst TConst
    | TFun Type Type
    | TBox Type
    deriving (Show, Eq, Ord)

data Scheme = Forall
    { _scmParams :: (Set TVar)
    , _scmBody :: Type
    } deriving (Show, Eq)
makeLenses ''Scheme

type Id = WithPos String

data TypedVar = TypedVar Id Type


@@ 79,7 98,7 @@ data Expr'
type Expr = WithPos Expr'

type Defs = Map String (Scheme, Expr)
type TypeDefs = Map String ([TVar], [(String, [Type])])
type TypeDefs = Map String ([TVar], [(Id, [Type])])
type Ctors = Map String (VariantIx, (String, [TVar]), [Type], Span)
type Externs = Map String Type



@@ 89,3 108,7 @@ instance Eq Con where

instance Ord Con where
    compare (Con c1 _ _) (Con c2 _ _) = compare c1 c2


startType :: Type
startType = TFun (TPrim TUnit) (TPrim TUnit)

M src/Ast.hs => src/Ast.hs +8 -12
@@ 1,6 1,5 @@
{-# LANGUAGE LambdaCase, TypeSynonymInstances, FlexibleInstances
           , MultiParamTypeClasses, TemplateHaskell, KindSignatures
           , DataKinds #-}
           , MultiParamTypeClasses, KindSignatures, DataKinds #-}

module Ast
    ( TVar(..)


@@ 8,8 7,6 @@ module Ast
    , TConst
    , Type(..)
    , Scheme(..)
    , scmParams
    , scmBody
    , IdCase(..)
    , Id(..)
    , idstr


@@ 22,14 19,13 @@ module Ast
    , TypeDef(..)
    , Extern(..)
    , Program(..)
    , startType
    , isFunLike
    , startType
    )
where

import qualified Data.Set as Set
import Data.Set (Set)
import Lens.Micro.Platform (makeLenses)
import Control.Arrow ((>>>))

import SrcPos


@@ 62,6 58,9 @@ data TPrim

type TConst = (String, [Type])

-- TODO: Now that AnnotAst.Type is not just an alias to Ast.Type, it makes sense
--       to add SrcPos-itions to Ast.Type! Would simplify / improve error
--       messages quite a bit.
data Type
    = TVar TVar
    | TPrim TPrim


@@ 70,11 69,8 @@ data Type
    | TBox Type
    deriving (Show, Eq, Ord)

data Scheme = Forall
    { _scmParams :: (Set TVar)
    , _scmBody :: Type
    } deriving (Show, Eq)
makeLenses ''Scheme
data Scheme = Forall SrcPos (Set TVar) Type
     deriving (Show, Eq)

data Pat
    = PConstruction SrcPos (Id 'Big) [Pat]


@@ 110,7 106,7 @@ data Expr'

type Expr = WithPos Expr'

type Def = (Id 'Small, (Maybe (WithPos Scheme), Expr))
type Def = (Id 'Small, (Maybe Scheme, Expr))

newtype ConstructorDefs = ConstructorDefs [(Id 'Big, [Type])]
    deriving (Show, Eq)

M src/Check.hs => src/Check.hs +89 -104
@@ 5,6 5,7 @@ module Check (typecheck) where
import Prelude hiding (span)
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Bifunctor
import Data.Bitraversable
import Data.Foldable


@@ 13,13 14,12 @@ import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import qualified Data.Set as Set
import Data.Set (Set)
import Data.Maybe

import Misc
import SrcPos
import Subst
import qualified Ast
import Ast (Id(..), IdCase(..), idstr, Type(..), TVar(..), TPrim(..))
import Ast (Id(..), IdCase(..), TVar(..), TPrim(..), idstr)
import TypeErr
import AnnotAst (VariantIx)
import qualified AnnotAst as An


@@ 34,7 34,7 @@ typecheck (Ast.Program defs tdefs externs) = runExcept $ do
    (externs', inferred, substs) <- inferTopDefs tdefs' ctors externs defs
    let substd = substTopDefs substs inferred
    checkTypeVarsBound substd
    let mTypeDefs = fmap (map fst . snd) tdefs'
    let mTypeDefs = fmap (map (unpos . fst) . snd) tdefs'
    desugared <- compileDecisionTreesAndDesugar mTypeDefs substd
    checkStartDefined desugared
    let tdefs'' = fmap (second (map snd)) tdefs'


@@ 43,113 43,96 @@ typecheck (Ast.Program defs tdefs externs) = runExcept $ do
    checkStartDefined ds =
        when (not (Map.member "start" ds)) (throwError StartNotDefined)

type CheckTypeDefs a
    = ReaderT
          (Map String Int)
          (StateT (An.TypeDefs, An.Ctors) (Except TypeErr))
          a

checkTypeDefs :: [Ast.TypeDef] -> Except TypeErr (An.TypeDefs, An.Ctors)
checkTypeDefs tdefs = do
    (tdefs', ctors) <- checkTypeDefsNoConflicting tdefs
    let tdefs'' = fmap (second (map snd)) tdefs'
    forM_ (Map.toList tdefs')
        $ \tdef -> checkTConstsDefs tdefs'' tdef *> assertNoRec tdefs' tdef
    pure (tdefs'', ctors)
  where
    -- | Check that constructurs don't refer to undefined types and that TConsts
    --   are of correct arity.
    checkTConstsDefs tds (_, (_, cs)) = forM_ cs (checkTConstsCtor tds)
    checkTConstsCtor tds (cpos, (_, ts)) = forM_ ts (checkType' tds cpos)
    -- | Check that type definitions are not recursive without indirection and
    --   that constructors don't refer to undefined types.
    assertNoRec tds (x, (_, cs)) = assertNoRecCtors tds x Map.empty cs
    assertNoRecCtors tds x s =
        mapM_ $ \(cpos, (_, ts)) ->
            forM_ ts (assertNoRecType tds x cpos . subst s)
    assertNoRecType tds x cpos = \case
        TVar _ -> pure ()
        TPrim _ -> pure ()
        TConst (y, ts) -> do
            when (x == y) $ throwError (RecTypeDef x cpos)
            let (tvs, cs) = tds Map.! y
            let substs = Map.fromList (zip tvs ts)
            assertNoRecCtors tds x substs cs
        TFun _ _ -> pure ()
        TBox _ -> pure ()

-- | Check that there are no conflicting type names or constructor names.
checkTypeDefsNoConflicting
    :: [Ast.TypeDef]
    -> Except
           TypeErr
           ( Map String ([TVar], [(SrcPos, (String, [Type]))])
           , Map String (VariantIx, (String, [TVar]), [Type], Span)
           )
checkTypeDefsNoConflicting =
    flip foldM (builtinDataTypes, builtinConstructors)
        $ \(tds', csAcc) td@(Ast.TypeDef x _ _) -> do
            when (Map.member (idstr x) tds') (throwError (ConflictingTypeDef x))
            (td', cs) <- checkTypeDef td
            case listToMaybe (Map.elems (Map.intersection cs csAcc)) of
                Just (cId, _) -> throwError (ConflictingCtorDef cId)
                Nothing ->
                    pure
                        ( uncurry Map.insert td' tds'
                        , Map.union (fmap snd cs) csAcc
                        )

checkTypeDef
    :: Ast.TypeDef
    -> Except
           TypeErr
           ( (String, ([TVar], [(SrcPos, (String, [Type]))]))
           , Map
                 String
                 (Id 'Big, (VariantIx, (String, [TVar]), [Type], Span))
           )
checkTypeDef (Ast.TypeDef x' ps (Ast.ConstructorDefs cs)) = do
    let x = idstr x'
    let tdefsParams =
            Map.union (fmap (length . fst) builtinDataTypes) $ Map.fromList
                (map (\(Ast.TypeDef x ps _) -> (idstr x, length ps)) tdefs)
    (tdefs', ctors) <- execStateT
        (runReaderT (forM_ tdefs checkTypeDef) tdefsParams)
        (builtinDataTypes, builtinConstructors)
    forM_ (Map.toList tdefs') (assertNoRec tdefs')
    pure (tdefs', ctors)

checkTypeDef :: Ast.TypeDef -> CheckTypeDefs ()
checkTypeDef (Ast.TypeDef (Ast.Id (WithPos xpos x)) ps cs) = do
    tAlreadyDefined <- gets (Map.member x . fst)
    when tAlreadyDefined (throwError (ConflictingTypeDef xpos x))
    let ps' = map TVExplicit ps
    let cs' = map (\(Id (WithPos p y), ts) -> (p, (y, ts))) cs
    let cSpan = fromIntegral (length cs)
    cs''' <- foldM
        (\cs'' (i, (cx, cps)) -> if Map.member (idstr cx) cs''
            then throwError (ConflictingCtorDef cx)
            else pure
                (Map.insert (idstr cx) (cx, (i, (x, ps'), cps, cSpan)) cs'')
        )
        Map.empty
        (zip [0 ..] cs)
    pure ((x, (ps', cs')), cs''')

builtinDataTypes :: Map String ([TVar], [(SrcPos, (String, [Type]))])
builtinDataTypes = Map.fromList
    (map (\(x, ps, cs) -> (x, (ps, map (dummyPos, ) cs))) builtinDataTypes')

builtinConstructors :: Map String (VariantIx, (String, [TVar]), [Type], Span)
    cs' <- checkCtors (x, ps') cs
    modify (first (Map.insert x (ps', cs')))

checkCtors
    :: (String, [TVar])
    -> Ast.ConstructorDefs
    -> CheckTypeDefs [(An.Id, [An.Type])]
checkCtors parent (Ast.ConstructorDefs cs) =
    let cspan = fromIntegral (length cs)
    in mapM (checkCtor cspan) (zip [0 ..] cs)
  where
    checkCtor cspan (i, (Id c'@(WithPos pos c), ts)) = do
        cAlreadyDefined <- gets (Map.member c . snd)
        when cAlreadyDefined (throwError (ConflictingCtorDef pos c))
        ts' <- mapM checkType ts
        modify (second (Map.insert c (i, parent, ts', cspan)))
        pure (c', ts')
    checkType t =
        ask >>= \tdefs -> checkType' (\x -> Map.lookup x tdefs) pos t

builtinDataTypes :: An.TypeDefs
builtinDataTypes = Map.fromList $ map
    (\(x, ps, cs) -> (x, (ps, map (first (WithPos dummyPos)) cs)))
    builtinDataTypes'

builtinConstructors :: An.Ctors
builtinConstructors = Map.unions (map builtinConstructors' builtinDataTypes')
  where
    builtinConstructors' (x, ps, cs) =
        let cSpan = fromIntegral (length cs)
        in
            foldl'
                (\csAcc (i, (cx, cps)) ->
                    Map.insert cx (i, (x, ps), cps, cSpan) csAcc
                )
                Map.empty
                (zip [0 ..] cs)

builtinConstructors'
    :: (String, [TVar], [(String, [Type])])
    -> Map String (VariantIx, (String, [TVar]), [Type], Span)
builtinConstructors' (x, ps, cs) =
    let cSpan = fromIntegral (length cs)
    in
        foldl'
            (\csAcc (i, (cx, cps)) ->
                Map.insert cx (i, (x, ps), cps, cSpan) csAcc
            )
            Map.empty
            (zip [0 ..] cs)

builtinDataTypes' :: [(String, [TVar], [(String, [Type])])]
builtinDataTypes' :: [(String, [TVar], [(String, [An.Type])])]
builtinDataTypes' =
    [ ( "Array"
      , [TVImplicit 0]
      , [("Array", [TBox (TVar (TVImplicit 0)), TPrim TNat])]
      , [("Array", [An.TBox (An.TVar (TVImplicit 0)), An.TPrim TNat])]
      )
    , ("Str", [], [("Str", [TConst ("Array", [TPrim TNat8])])])
    , ("Str", [], [("Str", [An.TConst ("Array", [An.TPrim TNat8])])])
    , ( "Pair"
      , [TVImplicit 0, TVImplicit 1]
      , [("Pair", [TVar (TVImplicit 0), TVar (TVImplicit 1)])]
      , [("Pair", [An.TVar (TVImplicit 0), An.TVar (TVImplicit 1)])]
      )
    ]

assertNoRec
    :: An.TypeDefs
    -> (String, ([TVar], [(An.Id, [An.Type])]))
    -> Except TypeErr ()
assertNoRec tdefs' (x, (_, ctors)) = assertNoRec' ctors Map.empty
  where
    assertNoRec' cs s =
        forM_ cs $ \(WithPos cpos _, cts) ->
            forM_ cts (assertNoRecType cpos . subst s)
    assertNoRecType cpos = \case
        An.TConst (y, ts) -> do
            when (x == y) $ throwError (RecTypeDef x cpos)
            let (tvs, cs) = tdefs' Map.! y
            let substs = Map.fromList (zip tvs ts)
            assertNoRec' cs substs
        _ -> pure ()

type Bound = ReaderT (Set TVar) (Except TypeErr) ()

-- TODO: Many of these positions are weird and kind of arbitrary, man. They may


@@ 186,13 169,13 @@ checkTypeVarsBound ds = runReaderT (boundInDefs ds) Set.empty
        An.Deref x -> boundInExpr x
    boundInType :: SrcPos -> An.Type -> Bound
    boundInType pos = \case
        TVar tv -> do
        An.TVar tv -> do
            bound <- ask
            when (not (Set.member tv bound)) (throwError (UnboundTVar pos))
        TPrim _ -> pure ()
        TConst (_, ts) -> forM_ ts (boundInType pos)
        TFun ft at -> forM_ [ft, at] (boundInType pos)
        TBox t -> boundInType pos t
        An.TPrim _ -> pure ()
        An.TConst (_, ts) -> forM_ ts (boundInType pos)
        An.TFun ft at -> forM_ [ft, at] (boundInType pos)
        An.TBox t -> boundInType pos t
    boundInCases cs = forM_ cs (bimapM boundInPat boundInExpr)
    boundInPat (WithPos pos pat) = case pat of
        An.PVar (An.TypedVar _ t) -> boundInType pos t


@@ 208,7 191,7 @@ compileDecisionTreesAndDesugar tdefs = compDefs
    compDefs = mapM compDef
    compDef = bimapM pure compExpr
    compExpr :: An.Expr -> Except TypeErr Des.Expr
    compExpr (WithPos pos e) = case e of
    compExpr (WithPos pos expr) = case expr of
        An.Lit c -> pure (Des.Lit c)
        An.Var (An.TypedVar (WithPos _ x) t) ->
            pure (Des.Var (Des.TypedVar x t))


@@ 233,8 216,10 @@ compileDecisionTreesAndDesugar tdefs = compDefs
                params = zip xs ts
                args = map (Des.Var . uncurry Des.TypedVar) params
            in pure $ snd $ foldr
                (\(p, pt) (bt, b) -> (TFun pt bt, Des.Fun (p, pt) (b, bt)))
                (TConst inst, Des.Ction v span' inst args)
                (\(p, pt) (bt, b) ->
                    (An.TFun pt bt, Des.Fun (p, pt) (b, bt))
                )
                (An.TConst inst, Des.Ction v span' inst args)
                params
        An.Box x -> fmap Des.Box (compExpr x)
        An.Deref x -> fmap Des.Deref (compExpr x)

M src/Codegen.hs => src/Codegen.hs +1 -1
@@ 29,7 29,7 @@ import Control.Applicative
import Lens.Micro.Platform (modifying, use, assign, to, view)

import Misc
import PrettyAst ()
import Pretty
import FreeVars
import qualified MonoAst
import MonoAst hiding (Type, Const)

M src/Infer.hs => src/Infer.hs +44 -34
@@ 21,7 21,7 @@ import SrcPos
import FreeVars
import Subst
import qualified Ast
import Ast (Id(..), IdCase(..), idstr, scmBody, isFunLike)
import Ast (Id(..), IdCase(..), idstr, isFunLike)
import TypeErr
import AnnotAst hiding (Id)



@@ 69,19 69,9 @@ inferTopDefs tdefs ctors externs defs =
    inferTopDefs' = do
        externs' <- checkExterns externs
        let externs'' = fmap (Forall Set.empty) externs'
        defs' <- checkStartType defs
        defs'' <- augment envDefs externs'' (inferDefs defs')
        defs'' <- augment envDefs externs'' (inferDefs defs)
        s <- use substs
        pure (externs', defs'', s)
    checkStartType = \case
        (x@(Id (WithPos _ "start")), (s, b)) : ds ->
            if s == Nothing || unpos (fromJust s) == startScheme
                then pure
                    ((x, (Just (WithPos dummyPos startScheme), b)) : ds)
                else throwError (WrongStartType (fromJust s))
        d : ds -> fmap (d :) (checkStartType ds)
        [] -> pure []
    startScheme = Forall Set.empty startType

-- TODO: Check that the types of the externs are valid more than just not
--       containing type vars. E.g., they may not refer to undefined types, duh.


@@ 97,21 87,29 @@ checkExterns = fmap Map.fromList . mapM checkExtern
checkType :: SrcPos -> Ast.Type -> Infer Type
checkType pos t = do
    tds <- view envTypeDefs
    lift (lift (checkType' tds pos t))

checkType' :: TypeDefs -> SrcPos -> Ast.Type -> Except TypeErr Type
checkType' tds pos = \case
    TVar v -> pure (TVar v)
    TPrim p -> pure (TPrim p)
    TConst tc -> fmap TConst (checkTConst tc)
    TFun f a -> liftA2 TFun (checkType' tds pos f) (checkType' tds pos a)
    TBox t -> fmap TBox (checkType' tds pos t)
    checkType' (\x -> fmap (length . fst) (Map.lookup x tds)) pos t

checkType'
    :: MonadError TypeErr m
    => (String -> Maybe Int)
    -> SrcPos
    -> Ast.Type
    -> m Type
checkType' tdefsParams pos = checkType''
  where
    checkTConst (x, inst) = case Map.lookup x tds of
        Just (tvs, _) -> do
            let (expectedN, foundN) = (length tvs, length inst)
    checkType'' = \case
        Ast.TVar v -> pure (TVar v)
        Ast.TPrim p -> pure (TPrim p)
        Ast.TConst tc -> fmap TConst (checkTConst tc)
        Ast.TFun f a -> liftA2 TFun (checkType'' f) (checkType'' a)
        Ast.TBox t -> fmap TBox (checkType'' t)
    checkTConst (x, inst) = case tdefsParams x of
        Just expectedN -> do
            let foundN = length inst
            if (expectedN == foundN)
                then pure (x, inst)
                then do
                    inst' <- mapM checkType'' inst
                    pure (x, inst')
                else throwError
                    (TypeInstArityMismatch pos x expectedN foundN)
        Nothing -> throwError (UndefType pos x)


@@ 150,9 148,8 @@ inferDefsComponents = \case
                CyclicSCC verts' -> (verts', True)
        let (idents, rhss) = unzip verts
        let (mayscms, bodies) = unzip rhss
        checkUserSchemes (catMaybes mayscms)
        let mayscms' = map (fmap unpos) mayscms
        let names = map idstr idents
        mayscms' <- mapM checkScheme (zip names mayscms)
        ts <- replicateM (length names) fresh
        let scms = map
                (\(mayscm, t) -> fromMaybe (Forall Set.empty t) mayscm)


@@ 174,10 171,20 @@ inferDefsComponents = \case
        pure (Map.union annotRest annotDefs)

-- | Verify that user-provided type signature schemes are valid
checkUserSchemes :: [WithPos Scheme] -> Infer ()
checkUserSchemes scms = forM_ scms $ \(WithPos p s1@(Forall _ t)) ->
    generalize t
        >>= \s2 -> when (s1 /= s2) (throwError (InvalidUserTypeSig p s1 s2))
checkScheme :: (String, Maybe Ast.Scheme) -> Infer (Maybe Scheme)
checkScheme = \case
    ("start", Nothing) -> pure (Just (Forall Set.empty startType))
    ("start", Just s@(Ast.Forall pos vs t))
        | Set.size vs /= 0 || t /= Ast.startType -> throwError
            (WrongStartType pos s)
    (_, Nothing) -> pure Nothing
    (_, Just (Ast.Forall pos vs t)) -> do
        t' <- checkType pos t
        let s1 = Forall vs t'
        s2 <- generalize t'
        if (s1 == s2)
            then pure (Just s1)
            else throwError (InvalidUserTypeSig pos s1 s2)

infer :: Ast.Expr -> Infer (Type, Expr)
infer (WithPos pos e) = fmap (second (WithPos pos)) $ case e of


@@ 207,8 214,9 @@ infer (WithPos pos e) = fmap (second (WithPos pos)) $ case e of
        pure (bt, Let annotDefs b')
    Ast.TypeAscr x t -> do
        (tx, WithPos _ x') <- infer x
        unify (Expected t) (Found (getPos x) tx)
        pure (t, x')
        t' <- checkType pos t
        unify (Expected t') (Found (getPos x) tx)
        pure (t', x')
    Ast.Match matchee cases -> do
        (tmatchee, matchee') <- infer matchee
        (tbody, cases') <- inferCases (Expected tmatchee) cases


@@ 295,7 303,9 @@ inferPatConstruction pos c cArgs = do
    (variantIx, tdefLhs, cParams, cSpan) <- lookupEnvConstructor c
    let arity = length cParams
    let nArgs = length cArgs
    unless (arity == nArgs) (throwError (CtorArityMismatch pos c arity nArgs))
    unless
        (arity == nArgs)
        (throwError (CtorArityMismatch pos (idstr c) arity nArgs))
    (tdefInst, cParams') <- instantiateConstructorOfTypeDef tdefLhs cParams
    let t = TConst tdefInst
    (cArgTs, cArgs', cArgsVars) <- fmap unzip3 (mapM inferPat cArgs)

M src/Match.hs => src/Match.hs +1 -0
@@ 21,6 21,7 @@ import Data.Word
import Lens.Micro.Platform (makeLenses, view, to)

import Misc hiding (augment)
import Pretty
import SrcPos
import TypeErr
import qualified AnnotAst as An

M src/Misc.hs => src/Misc.hs +0 -9
@@ 4,8 4,6 @@ module Misc
    ( ice
    , nyi
    , precalate
    , pretty
    , Pretty(..)
    , indent
    , both
    , secondM


@@ 48,13 46,6 @@ precalate prefix = \case
    [] -> []
    xs -> prefix ++ intercalate prefix xs

pretty :: Pretty a => a -> String
pretty = pretty' 0

-- Pretty print starting at some indentation depth
class Pretty a where
    pretty' :: Int -> a -> String

indent :: Int -> String
indent = flip replicate ' '


M src/Parse.hs => src/Parse.hs +11 -10
@@ 153,15 153,15 @@ def :: SrcPos -> Parser Def
def topPos = defUntyped topPos <|> defTyped topPos

defUntyped :: SrcPos -> Parser Def
defUntyped = (reserved "define" *>) . def' (pure Nothing)
defUntyped pos = reserved "define" *> def' (pure Nothing) pos

defTyped :: SrcPos -> Parser Def
defTyped = (reserved "define:" *>) . def' (fmap Just scheme)
defTyped pos = reserved "define:" *> def' (fmap Just scheme) pos

def'
    :: Parser (Maybe (WithPos Scheme))
    :: Parser (Maybe Scheme)
    -> SrcPos
    -> Parser (Id 'Small, (Maybe (WithPos Scheme), Expr))
    -> Parser (Id 'Small, (Maybe Scheme, Expr))
def' schemeParser topPos = varDef <|> funDef
  where
    varDef = do


@@ 247,12 247,13 @@ pat = choice [patInt, patBool, patStr, patCtor, patVar, ppat]
    patBox pos = reserved "Box" *> fmap (PBox pos) pat
    patCtion pos = liftM3 PConstruction (pure pos) big' (some pat)

scheme :: Parser (WithPos Scheme)
scheme = withPos $ wrap nonptype <|> (parens (universal <|> wrap ptype))
  where
    wrap = fmap (Forall Set.empty)
    universal = reserved "forall" *> liftA2 Forall tvars type_
    tvars = parens (fmap Set.fromList (many tvar))
scheme :: Parser Scheme
scheme = do
    pos <- getSrcPos
    let wrap = fmap (Forall pos Set.empty)
        universal = reserved "forall" *> liftA2 (Forall pos) tvars type_
        tvars = parens (fmap Set.fromList (many tvar))
    wrap nonptype <|> (parens (universal <|> wrap ptype))

type_ :: Parser Type
type_ = nonptype <|> parens ptype

R src/PrettyAst.hs => src/Pretty.hs +122 -79
@@ 1,44 1,61 @@
{-# LANGUAGE LambdaCase #-}

module PrettyAst () where
module Pretty (pretty, Pretty(..)) where

import Prelude hiding (showChar)
import Data.List
import Data.Bifunctor
import qualified Data.Set as Set
import Data.Set (Set)

import Misc
import SrcPos
import Ast
import qualified Ast
import qualified AnnotAst as An


instance Pretty Program where
-- Pretty print starting at some indentation depth
class Pretty a where
    pretty' :: Int -> a -> String

pretty :: Pretty a => a -> String
pretty = pretty' 0

spcPretty :: Pretty a => [a] -> String
spcPretty = unwords . map pretty


instance Pretty a => Pretty (WithPos a) where
    pretty' d = pretty' d . unpos


instance Pretty Ast.Program where
    pretty' = prettyProg
instance Pretty Extern where
instance Pretty Ast.Extern where
    pretty' = prettyExtern
instance Pretty ConstructorDefs where
instance Pretty Ast.ConstructorDefs where
    pretty' = prettyConstructorDefs
instance Pretty TypeDef where
instance Pretty Ast.TypeDef where
    pretty' = prettyTypeDef
instance Pretty Expr' where
instance Pretty Ast.Expr' where
    pretty' = prettyExpr'
instance Pretty Pat where
instance Pretty Ast.Pat where
    pretty' _ = prettyPat
instance Pretty Const where
instance Pretty Ast.Const where
    pretty' _ = prettyConst
instance Pretty Scheme where
    pretty' _ = prettyScheme
instance Pretty Type where
instance Pretty Ast.Scheme where
    pretty' _ (Ast.Forall _ ps t) = prettyScheme ps t
instance Pretty Ast.Type where
    pretty' _ = prettyType
instance Pretty TPrim where
instance Pretty Ast.TPrim where
    pretty' _ = prettyTPrim
instance Pretty TVar where
instance Pretty Ast.TVar where
    pretty' _ = prettyTVar
instance Pretty (Id a) where
    pretty' _ = idstr

instance Pretty (Ast.Id a) where
    pretty' _ = Ast.idstr

prettyProg :: Int -> Program -> String
prettyProg d (Program defs tdefs externs) =
prettyProg :: Int -> Ast.Program -> String
prettyProg d (Ast.Program defs tdefs externs) =
    let
        prettyDef = \case
            (name, (Just scm, body)) -> concat


@@ 52,12 69,12 @@ prettyProg d (Program defs tdefs externs) =
                ]
    in unlines (map prettyDef defs ++ map pretty tdefs ++ map pretty externs)

prettyExtern :: Int -> Extern -> String
prettyExtern _ (Extern name t) =
    concat ["(extern ", idstr name, " ", pretty t, ")"]
prettyExtern :: Int -> Ast.Extern -> String
prettyExtern _ (Ast.Extern name t) =
    concat ["(extern ", Ast.idstr name, " ", pretty t, ")"]

prettyTypeDef :: Int -> TypeDef -> String
prettyTypeDef d (TypeDef name params constrs) = concat
prettyTypeDef :: Int -> Ast.TypeDef -> String
prettyTypeDef d (Ast.TypeDef name params constrs) = concat
    [ "(type "
    , if null params
        then pretty name


@@ 65,8 82,8 @@ prettyTypeDef d (TypeDef name params constrs) = concat
    , "\n" ++ indent (d + 2) ++ pretty' (d + 2) constrs ++ ")"
    ]

prettyConstructorDefs :: Int -> ConstructorDefs -> String
prettyConstructorDefs d (ConstructorDefs cs) = intercalate
prettyConstructorDefs :: Int -> Ast.ConstructorDefs -> String
prettyConstructorDefs d (Ast.ConstructorDefs cs) = intercalate
    ("\n" ++ indent d)
    (map prettyConstrDef cs)
  where


@@ 74,20 91,20 @@ prettyConstructorDefs d (ConstructorDefs cs) = intercalate
        (c, []) -> pretty c
        (c, ts) -> concat ["(", pretty c, " ", spcPretty ts, ")"]

prettyExpr' :: Int -> Expr' -> String
prettyExpr' :: Int -> Ast.Expr' -> String
prettyExpr' d = \case
    Lit l -> pretty l
    Var v -> idstr v
    App f x -> concat
    Ast.Lit l -> pretty l
    Ast.Var v -> Ast.idstr v
    Ast.App f x -> concat
        [ "(" ++ pretty' (d + 1) f ++ "\n"
        , indent (d + 1) ++ pretty' (d + 1) x ++ ")"
        ]
    If pred' cons alt -> concat
    Ast.If pred' cons alt -> concat
        [ "(if " ++ pretty' (d + 4) pred' ++ "\n"
        , indent (d + 4) ++ pretty' (d + 4) cons ++ "\n"
        , indent (d + 2) ++ pretty' (d + 2) alt ++ ")"
        ]
    Fun param body -> concat
    Ast.Fun param body -> concat
        [ "(fun ("
        , prettyPat param
        , ")\n"


@@ 95,7 112,7 @@ prettyExpr' d = \case
        , pretty' (d + 2) body
        , ")"
        ]
    Let binds body -> concat
    Ast.Let binds body -> concat
        [ "(let ["
        , intercalate ("\n" ++ indent (d + 6)) (map (prettyDef (d + 6)) binds)
        , "]\n"


@@ 112,47 129,47 @@ prettyExpr' d = \case
                [ "[" ++ pretty' (d' + 1) name ++ "\n"
                , indent (d' + 1) ++ pretty' (d' + 1) dbody ++ "]"
                ]
    TypeAscr e t ->
    Ast.TypeAscr e t ->
        concat ["(: ", pretty' (d + 3) e, "\n", pretty' (d + 3) t, ")"]
    Match e cs -> concat
    Ast.Match e cs -> concat
        [ "(match " ++ pretty' (d + 7) e
        , precalate
            ("\n" ++ indent (d + 2))
            (map (prettyBracketPair (d + 2)) cs)
        , ")"
        ]
    FunMatch cs -> concat
    Ast.FunMatch cs -> concat
        [ "(fun-match"
        , precalate
            ("\n" ++ indent (d + 2))
            (map (prettyBracketPair (d + 2)) cs)
        , ")"
        ]
    Ctor c -> pretty c
    Box e -> concat ["(box ", pretty' (d + 5) e, ")"]
    Deref e -> concat ["(deref ", pretty' (d + 7) e, ")"]
    Ast.Ctor c -> pretty c
    Ast.Box e -> concat ["(box ", pretty' (d + 5) e, ")"]
    Ast.Deref e -> concat ["(deref ", pretty' (d + 7) e, ")"]

prettyBracketPair :: (Pretty a, Pretty b) => Int -> (a, b) -> String
prettyBracketPair d (a, b) = concat
    ["[", pretty' (d + 1) a, "\n", indent (d + 1), pretty' (d + 1) b, "]"]

prettyPat :: Pat -> String
prettyPat :: Ast.Pat -> String
prettyPat = \case
    PConstruction _ (Id (WithPos _ c)) ps ->
    Ast.PConstruction _ (Ast.Id (WithPos _ c)) ps ->
        if null ps then c else concat ["(", c, " ", spcPretty ps, ")"]
    PInt _ n -> show n
    PBool _ b -> if b then "true" else "false"
    PStr _ s -> prettyStr s
    PVar v -> idstr v
    PBox _ p -> "(Box " ++ prettyPat p ++ ")"
    Ast.PInt _ n -> show n
    Ast.PBool _ b -> if b then "true" else "false"
    Ast.PStr _ s -> prettyStr s
    Ast.PVar v -> Ast.idstr v
    Ast.PBox _ p -> "(Box " ++ prettyPat p ++ ")"

prettyConst :: Const -> String
prettyConst :: Ast.Const -> String
prettyConst = \case
    Unit -> "unit"
    Int n -> show n
    Double x -> show x
    Str s -> prettyStr s
    Bool b -> if b then "true" else "false"
    Ast.Unit -> "unit"
    Ast.Int n -> show n
    Ast.Double x -> show x
    Ast.Str s -> prettyStr s
    Ast.Bool b -> if b then "true" else "false"

prettyStr :: String -> String
prettyStr s = '"' : (s >>= showChar) ++ "\""


@@ 170,47 187,73 @@ prettyStr s = '"' : (s >>= showChar) ++ "\""
        '\"' -> "\\\""
        c -> [c]

prettyScheme :: Scheme -> String
prettyScheme (Forall ps t) =
prettyScheme :: (Pretty p, Pretty t) => Set p -> t -> String
prettyScheme ps t =
    concat ["(forall [" ++ spcPretty (Set.toList ps) ++ "] ", pretty t ++ ")"]

prettyType :: Type -> String
prettyType :: Ast.Type -> String
prettyType = \case
    Ast.TVar tv -> pretty tv
    Ast.TPrim c -> pretty c
    Ast.TFun a b -> prettyTFun a b
    Ast.TBox t -> "(Box " ++ pretty t ++ ")"
    Ast.TConst (c, ts) -> case ts of
        [] -> c
        _ -> concat ["(", c, " ", spcPretty ts, ")"]
    Ast.TBox t -> prettyTBox t
    Ast.TConst tc -> prettyTConst tc

prettyTConst :: Pretty t => (String, [t]) -> String
prettyTConst (c, ts) = case ts of
    [] -> c
    _ -> concat ["(", c, " ", spcPretty ts, ")"]

prettyTBox :: Pretty t => t -> String
prettyTBox t = "(Box " ++ pretty t ++ ")"

prettyTFun :: Type -> Type -> String
prettyTFun :: Ast.Type -> Ast.Type -> String
prettyTFun a b =
    let
        (bParams, bBody) = f b
        f = \case
            TFun a' b' -> first (a' :) (f b')
            Ast.TFun a' b' -> first (a' :) (f b')
            t -> ([], t)
    in concat ["(Fun ", pretty a, " ", spcPretty (bParams ++ [bBody]), ")"]

prettyTPrim :: TPrim -> String
prettyTPrim :: Ast.TPrim -> String
prettyTPrim = \case
    TUnit -> "Unit"
    TNat8 -> "Nat8"
    TNat16 -> "Nat16"
    TNat32 -> "Nat32"
    TNat -> "Nat"
    TInt8 -> "Int8"
    TInt16 -> "Int16"
    TInt32 -> "Int32"
    TInt -> "Int"
    TDouble -> "Double"
    TBool -> "Bool"

prettyTVar :: TVar -> String
    Ast.TUnit -> "Unit"
    Ast.TNat8 -> "Nat8"
    Ast.TNat16 -> "Nat16"
    Ast.TNat32 -> "Nat32"
    Ast.TNat -> "Nat"
    Ast.TInt8 -> "Int8"
    Ast.TInt16 -> "Int16"
    Ast.TInt32 -> "Int32"
    Ast.TInt -> "Int"
    Ast.TDouble -> "Double"
    Ast.TBool -> "Bool"

prettyTVar :: Ast.TVar -> String
prettyTVar = \case
    TVExplicit v -> idstr v
    TVImplicit n -> "#" ++ show n
    Ast.TVExplicit v -> Ast.idstr v
    Ast.TVImplicit n -> "#" ++ show n

spcPretty :: Pretty a => [a] -> String
spcPretty = unwords . map pretty

instance Pretty An.Scheme where
    pretty' _ (An.Forall ps t) = prettyScheme ps t
instance Pretty An.Type where
    pretty' _ = prettyAnType

prettyAnType :: An.Type -> String
prettyAnType = \case
    An.TVar tv -> pretty tv
    An.TPrim c -> pretty c
    An.TFun a b -> prettyAnTFun a b
    An.TBox t -> prettyTBox t
    An.TConst tc -> prettyTConst tc

prettyAnTFun :: An.Type -> An.Type -> String
prettyAnTFun a b =
    let
        (bParams, bBody) = f b
        f = \case
            An.TFun a' b' -> first (a' :) (f b')
            t -> ([], t)
    in concat ["(Fun ", pretty a, " ", spcPretty (bParams ++ [bBody]), ")"]

M src/SrcPos.hs => src/SrcPos.hs +0 -3
@@ 10,7 10,6 @@ where

import Text.Megaparsec.Pos

import Misc

newtype SrcPos = SrcPos SourcePos
    deriving (Show, Eq)


@@ 27,8 26,6 @@ instance Eq a => Eq (WithPos a) where
    (WithPos _ a) == (WithPos _ b) = a == b
instance Ord a => Ord (WithPos a) where
    compare (WithPos _ a) (WithPos _ b) = compare a b
instance Pretty a => Pretty (WithPos a) where
    pretty' d = pretty' d . unpos

instance HasPos (WithPos a) where
    getPos (WithPos p _) = p

M src/TypeErr.hs => src/TypeErr.hs +13 -12
@@ 6,30 6,31 @@ import Text.Megaparsec (SourcePos(..), unPos)

import Misc
import SrcPos
import Ast
import PrettyAst ()
import qualified Ast
import AnnotAst
import Pretty
import Parse


data TypeErr
    = StartNotDefined
    | InvalidUserTypeSig SrcPos Scheme Scheme
    | CtorArityMismatch SrcPos (Id 'Big) Int Int
    | CtorArityMismatch SrcPos String Int Int
    | ConflictingPatVarDefs SrcPos String
    | UndefCtor SrcPos String
    | UndefVar SrcPos String
    | InfType SrcPos Type Type TVar Type
    | UnificationFailed SrcPos Type Type Type Type
    | ConflictingTypeDef (Id 'Big)
    | ConflictingCtorDef (Id 'Big)
    | ConflictingTypeDef SrcPos String
    | ConflictingCtorDef SrcPos String
    | RedundantCase SrcPos
    | InexhaustivePats SrcPos String
    | ExternNotMonomorphic (Id 'Small) TVar
    | ExternNotMonomorphic (Ast.Id 'Ast.Small) TVar
    | FoundHole SrcPos
    | RecTypeDef String SrcPos
    | UndefType SrcPos String
    | UnboundTVar SrcPos
    | WrongStartType (WithPos Scheme)
    | WrongStartType SrcPos Ast.Scheme
    | RecursiveVarDef (WithPos String)
    | TypeInstArityMismatch SrcPos String Int Int
    | ConflictingVarDef SrcPos String


@@ 46,7 47,7 @@ printErr = \case
            ++ (", expected " ++ pretty s2)
    CtorArityMismatch p c arity nArgs ->
        posd p
            $ ("Arity mismatch for constructor `" ++ pretty c)
            $ ("Arity mismatch for constructor `" ++ c)
            ++ ("` in pattern.\nExpected " ++ show arity)
            ++ (", found " ++ show nArgs)
    ConflictingPatVarDefs p v ->


@@ 67,15 68,15 @@ printErr = \case
            $ ("Couldn't match type " ++ pretty t'2 ++ " with " ++ pretty t'1)
            ++ (".\nExpected type: " ++ pretty t1)
            ++ (".\nFound type: " ++ pretty t2 ++ ".")
    ConflictingTypeDef (Id (WithPos p x)) ->
    ConflictingTypeDef p x ->
        posd p $ "Conflicting definitions for type `" ++ x ++ "`."
    ConflictingCtorDef (Id (WithPos p x)) ->
    ConflictingCtorDef p x ->
        posd p $ "Conflicting definitions for constructor `" ++ x ++ "`."
    RedundantCase p -> posd p $ "Redundant case in pattern match."
    InexhaustivePats p patStr ->
        posd p $ "Inexhaustive patterns: " ++ patStr ++ " not covered."
    ExternNotMonomorphic name tv -> case tv of
        TVExplicit (Id (WithPos p tv')) ->
        TVExplicit (Ast.Id (WithPos p tv')) ->
            posd p
                $ ("Extern " ++ pretty name ++ " is not monomorphic. ")
                ++ ("Type variable " ++ tv' ++ " encountered in type signature")


@@ 91,7 92,7 @@ printErr = \case
        posd p
            $ "Could not fully infer type of expression.\n"
            ++ "Type annotations needed."
    WrongStartType (WithPos p s) ->
    WrongStartType p s ->
        posd p
            $ "Incorrect type of `start`.\n"
            ++ ("Expected: " ++ pretty startType)