~jojo/Carth

ba3f58733ff8d99d87ac8bbdef4cb445e470a9e8 — JoJo 1 year, 10 months ago 7f298ae
Check definedness & arity of TConsts in type defs

Thought I checked definedness before within assertNoRec, but that
check was only for "top-level" tconsts within a constructor.
3 files changed, 51 insertions(+), 20 deletions(-)

M src/Check.hs
M src/Parse.hs
M src/TypeErr.hs
M src/Check.hs => src/Check.hs +25 -9
@@ 42,9 42,27 @@ typecheck (Ast.Program defs tdefs externs) = runExcept $ do
checkTypeDefs :: [Ast.TypeDef] -> Except TypeErr (An.TypeDefs, An.Ctors)
checkTypeDefs tdefs = do
    (tdefs', ctors) <- checkTypeDefsNoConflicting tdefs
    forM_ (Map.toList tdefs') (assertNoRec tdefs')
    pure (fmap (second (map snd)) tdefs', ctors)
    let tdefs'' = fmap (second (map snd)) tdefs'
    forM_ (Map.toList tdefs')
        $ \tdef -> checkTConstsDefs tdefs'' tdef *> assertNoRec tdefs' tdef
    pure (tdefs'', ctors)
  where
    -- | Check that constructurs don't refer to undefined types and that TConsts
    --   are of correct arity.
    checkTConstsDefs tds (_, (_, cs)) = forM_ cs (checkTConstsCtor tds)
    checkTConstsCtor tds (cpos, (_, ts)) = forM_ ts (checkType tds cpos)
    checkType tds cpos = \case
        TVar _ -> pure ()
        TPrim _ -> pure ()
        TConst tc -> checkTConst tds cpos tc
        TFun f a -> checkType tds cpos f *> checkType tds cpos a
        TBox t -> checkType tds cpos t
    checkTConst tds cpos (x, inst) = case Map.lookup x tds of
        Just (tvs, _) -> do
            let (expectedN, foundN) = (length tvs, length inst)
            when (not (expectedN == foundN)) $ throwError
                (TypeInstArityMismatch cpos x expectedN foundN)
        Nothing -> throwError (UndefType cpos x)
    -- | Check that type definitions are not recursive without indirection and
    --   that constructors don't refer to undefined types.
    assertNoRec tds (x, (_, cs)) = assertNoRecCtors tds x Map.empty cs


@@ 54,13 72,11 @@ checkTypeDefs tdefs = do
    assertNoRecType tds x cpos = \case
        TVar _ -> pure ()
        TPrim _ -> pure ()
        TConst (y, ts) -> if x == y
            then throwError (RecTypeDef x cpos)
            else case Map.lookup y tds of
                Just (tvs, cs) ->
                    let substs = Map.fromList (zip tvs ts)
                    in assertNoRecCtors tds x substs cs
                Nothing -> throwError (UndefType cpos y)
        TConst (y, ts) -> do
            when (x == y) $ throwError (RecTypeDef x cpos)
            let (tvs, cs) = tds Map.! y
            let substs = Map.fromList (zip tvs ts)
            assertNoRecCtors tds x substs cs
        TFun _ _ -> pure ()
        TBox _ -> pure ()


M src/Parse.hs => src/Parse.hs +18 -10
@@ 22,6 22,7 @@ module Parse
    , ns_parens
    , def
    , getSrcPos
    , ns_tokenTree
    )
where



@@ 109,6 110,13 @@ parseModule filepath dir m visiteds nexts = do
parse' :: Parser a -> FilePath -> Source -> Either String a
parse' p name src = mapLeft errorBundlePretty (Mega.parse p name src)

-- | For use in TypeErr to get the length of the tokentree to draw a squiggly
--   line under it.
ns_tokenTree :: Parser ()
ns_tokenTree = choice
    [str $> (), num $> (), ident $> (), ns_parens (many tokenTree) $> ()]
    where tokenTree = andSkipSpaceAfter ns_tokenTree

toplevels :: Parser ([Import], [Def], [TypeDef], [Extern])
toplevels = do
    space


@@ 176,22 184,22 @@ ns_expr :: Parser Expr
ns_expr = withPos $ choice [unit, estr, ebool, var, num, eConstructor, pexpr]
  where
    unit = ns_reserved "unit" $> Lit Unit
    num = do
        neg <- option False (char '-' $> True)
        a <- eitherP
            (try (Lexer.decimal <* notFollowedBy (char '.')))
            Lexer.float
        let e = either
                (\n -> Int (if neg then -n else n))
                (\x -> Double (if neg then -x else x))
                a
        pure (Lit e)
    estr = fmap (Lit . Str) str
    ebool = fmap (Lit . Bool) bool
    pexpr =
        ns_parens $ choice
            [funMatch, match, if', fun, let', typeAscr, box, deref, app]

num :: Parser Expr'
num = do
    neg <- option False (char '-' $> True)
    a <- eitherP (try (Lexer.decimal <* notFollowedBy (char '.'))) Lexer.float
    let e = either
            (\n -> Int (if neg then -n else n))
            (\x -> Double (if neg then -x else x))
            a
    pure (Lit e)

bool :: Parser Bool
bool = (ns_reserved "true" $> True) <|> (ns_reserved "false" $> False)


M src/TypeErr.hs => src/TypeErr.hs +8 -1
@@ 32,6 32,7 @@ data TypeErr
    | UnboundTVar SrcPos
    | WrongStartType (WithPos Scheme)
    | RecursiveVarDef (WithPos String)
    | TypeInstArityMismatch SrcPos String Int Int
    deriving Show

type Message = String


@@ 89,7 90,7 @@ printErr = \case
            $ ("Type `" ++ x ++ "` ")
            ++ "has infinite size due to recursion without indirection.\n"
            ++ "Insert a pointer at some point to make it representable."
    UndefType p x -> posd p big $ "Undefined type `" ++ x ++ "` in constructor"
    UndefType p x -> posd p big $ "Undefined type `" ++ x ++ "`."
    UnboundTVar p ->
        posd p defOrExpr
            $ "Could not fully infer type of expression.\n"


@@ 102,6 103,11 @@ printErr = \case
    RecursiveVarDef (WithPos p x) ->
        posd p var
            $ ("Non-function variable definition `" ++ x ++ "` is recursive.")
    TypeInstArityMismatch p t expected found ->
        posd p tokenTree
            $ ("Arity mismatch for instantiation of type `" ++ pretty t)
            ++ ("`.\nExpected " ++ show expected)
            ++ (", found " ++ show found)
  where
    -- | Used to handle that the position of the generated nested lambdas of a
    --   definition of the form `(define (foo a b ...) ...)` is set to the


@@ 117,6 123,7 @@ printErr = \case
    eConstructor = Parse.eConstructor <||> wholeLine
    big = Parse.ns_big
    wholeLine = many Mega.anySingle
    tokenTree = Parse.ns_tokenTree
    (<||>) pa pb = (Mega.try pa $> ()) <|> (pb $> ())

posd :: SrcPos -> Parse.Parser a -> Message -> IO ()