~jojo/Carth

c763695a7f984f50a2151759f15c696779e78f02 — JoJo 1 year, 10 months ago 8c7156e
Check that non-func var defs aren't recursive

Because we can't compile that with strict evaluation!
4 files changed, 20 insertions(+), 8 deletions(-)

M src/AnnotAst.hs
M src/Ast.hs
M src/Infer.hs
M src/TypeErr.hs
M src/AnnotAst.hs => src/AnnotAst.hs +0 -4
@@ 68,10 68,6 @@ data Expr'
type Expr = WithPos Expr'

type Defs = Map String (Scheme, Expr)
-- type TypeDefs = Map String ([TVar], [[Type]])
type TypeDefs = Map String ([TVar], [(String, [Type])])
type Ctors = Map String (VariantIx, (String, [TVar]), [Type], Span)
type Externs = Map String Type

-- data Program = Program Expr Defs TypeDefs Externs
--     deriving (Show)

M src/Ast.hs => src/Ast.hs +7 -1
@@ 23,6 23,7 @@ module Ast
    , Extern(..)
    , Program(..)
    , startType
    , isFun
    )
where



@@ 132,7 133,7 @@ instance Eq Pat where
        _ -> False

instance FreeVars Def (Id Small) where
    freeVars (name, (_, body)) = Set.delete name (freeVars body)
    freeVars (_, (_, body)) = freeVars body

instance FreeVars Expr (Id Small) where
    freeVars = fvExpr


@@ 360,3 361,8 @@ idstr (Id (WithPos _ x)) = x

startType :: Type
startType = TFun (TPrim TUnit) (TPrim TUnit)

isFun :: Expr -> Bool
isFun (WithPos _ e) = case e of
    Fun _ _ -> True
    _ -> False

M src/Infer.hs => src/Infer.hs +9 -3
@@ 9,7 9,7 @@ import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Bifunctor
import Data.Graph (SCC(..), flattenSCC, stronglyConnComp)
import Data.Graph (SCC(..), stronglyConnComp)
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Maybe


@@ 22,7 22,7 @@ import FreeVars
import Subst
import NonEmpty
import qualified Ast
import Ast (Id(..), IdCase(..), idstr, scmBody)
import Ast (Id(..), IdCase(..), idstr, scmBody, isFun)
import TypeErr
import AnnotAst hiding (Id)
import Match


@@ 116,7 116,10 @@ inferDefsComponents :: [SCC Ast.Def] -> Infer Defs
inferDefsComponents = \case
    [] -> pure Map.empty
    (scc : sccs) -> do
        let (idents, rhss) = unzip (flattenSCC scc)
        let (verts, isCyclic) = case scc of
                AcyclicSCC vert -> ([vert], False)
                CyclicSCC verts' -> (verts', True)
        let (idents, rhss) = unzip verts
        let (mayscms, bodies) = unzip rhss
        checkUserSchemes (catMaybes mayscms)
        let mayscms' = map (fmap unpos) mayscms


@@ 125,6 128,9 @@ inferDefsComponents = \case
        let scms = map
                (\(mayscm, t) -> fromMaybe (Forall Set.empty t) mayscm)
                (zip mayscms' ts)
        forM_ (zip idents bodies) $ \(Id name, body) ->
            when (not (isFun body) && isCyclic)
                $ throwError (RecursiveVarDef name)
        bodies' <-
            withLocals (zip names scms)
            $ forM (zip bodies scms)

M src/TypeErr.hs => src/TypeErr.hs +4 -0
@@ 31,6 31,7 @@ data TypeErr
    | UndefType SrcPos String
    | UnboundTVar SrcPos
    | WrongStartType (WithPos Scheme)
    | RecursiveVarDef (WithPos String)
    deriving Show

type Message = String


@@ 98,6 99,9 @@ prettyErr = \case
            $ "Incorrect type of `start`.\n"
            ++ ("Expected: " ++ pretty startType)
            ++ ("\nFound: " ++ pretty s)
    RecursiveVarDef (WithPos p x) ->
        posd p var
            $ ("Non-function variable definition `" ++ x ++ "` is recursive.")
  where
    -- | Used to handle that the position of the generated nested lambdas of a
    --   definition of the form `(define (foo a b ...) ...)` is set to the