~jojo/Carth

ref: 55fb4f948f1f3797078b584dc60b4f7dd68b37ed Carth/src/Check.hs -rw-r--r-- 10.0 KiB
55fb4f94JoJo Check `cast` in Infer instead of Gen 4 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
{-# LANGUAGE DataKinds #-}

module Check (typecheck) where

import Prelude hiding (span)
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Bifunctor
import Data.Bitraversable
import Data.Foldable
import Control.Applicative
import qualified Data.Map as Map
import Data.Map (Map)
import qualified Data.Set as Set
import Data.Set (Set)

import Misc
import SrcPos
import Subst
import qualified Parsed
import Parsed (Id(..), TVar(..), idstr)
import Err
import qualified Inferred
import Match
import Infer
import TypeAst
import qualified Checked
import Checked (withPos, noPos, Virt(..))


-- | Type check a parsed program and produce its checked form.
--
--   The stages run in order, and the first 'TypeErr' raised by any of them is
--   the result: type definitions, externs, inference of the top-level
--   definitions, the bound-type-variable check, decision tree compilation,
--   and finally the check that a definition named @main@ exists.
typecheck :: Parsed.Program -> Either TypeErr Checked.Program
typecheck (Parsed.Program defs tdefs externs) = runExcept $ do
    (typeDefs, ctors) <- checkTypeDefs tdefs
    externs' <- checkExterns typeDefs externs
    inferred <- inferTopDefs typeDefs ctors externs' defs
    checkTypeVarsBound inferred
    compiled <- compileDecisionTrees (ctorNames typeDefs) inferred
    when ("main" `notElem` definedNames compiled) (throwError MainNotDefined)
    pure (Checked.Program compiled (ctorArgTypes typeDefs) externs')
  where
    -- Per type: just the (position-less) names of its constructors. This is
    -- the view handed to 'compileDecisionTrees'.
    ctorNames tds = fmap (\(_, cs) -> map (unpos . fst) cs) tds
    -- Per type: its parameters and, for each constructor, only the argument
    -- types. This is the view stored in the checked program.
    ctorArgTypes tds = fmap (second (map snd)) tds
    -- Names of all definitions in the checked program.
    definedNames ds = map fst (Checked.flattenDefs ds)

-- | The monad used while checking type definitions.
--
--   * Reader: maps a type name to its number of type parameters.
--   * State: the type definitions and constructors registered so far.
--   * Except: checking stops at the first 'TypeErr'.
type CheckTypeDefs a
    = ReaderT (Map String Int) (StateT (Inferred.TypeDefs, Inferred.Ctors) (Except TypeErr)) a

-- | Check all user type definitions on top of the builtin ones.
--
--   Returns the combined type definitions and constructor table. After all
--   definitions have been registered, every one of them (builtins included)
--   is passed through 'assertNoRec'.
checkTypeDefs :: [Parsed.TypeDef] -> Except TypeErr (Inferred.TypeDefs, Inferred.Ctors)
checkTypeDefs tdefs = do
    (checked, ctors) <- execStateT (runReaderT (mapM_ checkTypeDef tdefs) arities)
                                   (builtinDataTypes, builtinConstructors)
    mapM_ (assertNoRec checked) (Map.toList checked)
    pure (checked, ctors)
  where
    -- Number of type parameters of every known type name. 'Map.union' is
    -- left-biased, so a builtin entry wins over a user entry of the same name.
    arities = Map.union builtinArities userArities
    builtinArities = fmap (length . fst) builtinDataTypes
    userArities = Map.fromList (map arity tdefs)
    arity (Parsed.TypeDef x ps _) = (idstr x, length ps)

-- | Check a single type definition and register it in the state.
--
--   Fails with 'ConflictingTypeDef' if a type of the same name has already
--   been registered. The type's constructors are checked (and registered) by
--   'checkCtors' before the type itself is inserted.
checkTypeDef :: Parsed.TypeDef -> CheckTypeDefs ()
checkTypeDef (Parsed.TypeDef (Parsed.Id (WithPos namePos name)) params ctorDefs) = do
    knownTypes <- gets fst
    when (Map.member name knownTypes) $ throwError (ConflictingTypeDef namePos name)
    ctors <- checkCtors (name, tvars) ctorDefs
    modify (first (Map.insert name (tvars, ctors)))
  where
    tvars = map TVExplicit params

-- | Check the constructors of one type definition and register each of them
--   in the constructor table.
--
--   The first argument is the parent type: its name and type parameters.
--   Each constructor is stored with its index within the definition, the
--   parent, its checked argument types, and the total number of constructors
--   of the parent. Fails with 'ConflictingCtorDef' if a constructor name is
--   already registered.
checkCtors
    :: (String, [TVar])
    -> Parsed.ConstructorDefs
    -> CheckTypeDefs [(Inferred.Id, [Inferred.Type])]
checkCtors parent (Parsed.ConstructorDefs cs) = mapM checkCtor (zip [0 ..] cs)
  where
    nCtors = fromIntegral (length cs)
    checkCtor (ix, (Id ctorId@(WithPos pos name), argTypes)) = do
        taken <- gets (Map.member name . snd)
        when taken $ throwError (ConflictingCtorDef pos name)
        argTypes' <- mapM (checkArgType pos) argTypes
        modify (second (Map.insert name (ix, parent, argTypes', nCtors)))
        pure (ctorId, argTypes')
    -- Type names are resolved against the arity map in the reader environment.
    checkArgType pos t = do
        arities <- ask
        checkType'' (`Map.lookup` arities) pos t

-- | The builtin data types as a type definition map. Every constructor name
--   is tagged with a placeholder @\<builtin\>@ source position.
builtinDataTypes :: Inferred.TypeDefs
builtinDataTypes = Map.fromList
    [ (name, (params, map (first (WithPos builtinPos)) ctors))
    | (name, params, ctors) <- builtinDataTypes'
    ]
    where builtinPos = SrcPos "<builtin>" 0 0 Nothing

-- | The constructor table for the builtin data types.
--
--   Each constructor maps to its index within its type, the parent type
--   (name and parameters), its argument types, and the number of constructors
--   of the parent.
builtinConstructors :: Inferred.Ctors
builtinConstructors = Map.unions (map ctorsOf builtinDataTypes')
  where
    ctorsOf (tname, tparams, ctors) = Map.fromList
        [ (cname, (ix, (tname, tparams), cargs, nCtors))
        | (ix, (cname, cargs)) <- zip [0 ..] ctors
        ]
        where nCtors = fromIntegral (length ctors)

-- | The raw table of builtin data types: type name, type parameters, and the
--   constructors with their argument types.
builtinDataTypes' :: [(String, [TVar], [(String, [Inferred.Type])])]
builtinDataTypes' =
    [ ("Array", [tvA], [("Array", [Inferred.TBox (tvar tvA), Inferred.TPrim TNatSize])])
    , ("Str", [], [("Str", [tArray (Inferred.TPrim (TNat 8))])])
    , ("Cons", [tvA, tvB], [("Cons", [tvar tvA, tvar tvB])])
    , ("Unit", [], [unit'])
    , ("RealWorld", [], [("UnsafeRealWorld", [])])
    , ("Bool", [], [("False", []), ("True", [])])
      -- IO a wraps a function: RealWorld -> Cons a (Cons RealWorld Unit)
    , ( "IO"
      , [tvA]
      , [("IO", [Inferred.TFun realWorld (cons (tvar tvA) (cons realWorld (tc unit')))])]
      )
    ]
  where
    tc = Inferred.TConst
    tvar = Inferred.TVar
    tvA = TVImplicit "a"
    tvB = TVImplicit "b"
    unit' = ("Unit", [])
    realWorld = tc ("RealWorld", [])
    cons l r = tc ("Cons", [l, r])

-- | Fail with 'RecTypeDef' if the given type definition contains itself
--   without indirection.
--
--   The constructor argument types are walked; every type constant reached is
--   expanded into its own constructors (with its type arguments substituted
--   for its parameters) and walked in turn. Only type constants are followed,
--   so anything behind e.g. a box or a function type is not inspected.
--
--   NOTE(review): the walk only stops on reaching the type being checked or a
--   non-constant type. A cycle among *other* types that never mentions this
--   one would seem to be followed forever -- confirm.
assertNoRec
    :: Inferred.TypeDefs
    -> (String, ([TVar], [(Inferred.Id, [Inferred.Type])]))
    -> Except TypeErr ()
assertNoRec tdefs' (x, (_, ctors)) = walkCtors Map.empty ctors
  where
    -- Walk every argument type of every constructor, after applying the
    -- substitution of the enclosing type's parameters.
    walkCtors s cs = forM_ cs $ \(WithPos cpos _, cts) ->
        mapM_ (walkType cpos . subst s) cts
    walkType cpos (Inferred.TConst (y, ts))
        | x == y = throwError (RecTypeDef x cpos)
        | otherwise =
            let (tvs, cs) = tdefs' Map.! y
            in  walkCtors (Map.fromList (zip tvs ts)) cs
    walkType _ _ = pure ()

-- | Check the types of all extern declarations and combine them with the
--   builtin externs.
--
--   An extern whose type contains a free type variable is rejected with
--   'ExternNotMonomorphic'. 'Map.union' is left-biased, so a builtin extern
--   takes precedence over a declared one of the same name.
checkExterns :: Inferred.TypeDefs -> [Parsed.Extern] -> Except TypeErr Inferred.Externs
checkExterns tdefs externs = do
    checked <- mapM checkExtern externs
    pure (Map.union Inferred.builtinExterns (Map.fromList checked))
  where
    checkExtern (Parsed.Extern name t) = do
        let pos = getPos name
        t' <- checkType' tdefs pos t
        maybe (pure (idstr name, (t', pos)))
              (throwError . ExternNotMonomorphic name)
              (Set.lookupMin (Inferred.ftv t'))

-- | A check that reads the set of currently bound type variables from the
--   reader environment, yields no result, and may fail with a 'TypeErr'.
type Bound = ReaderT (Set TVar) (Except TypeErr) ()

-- TODO: Many of these positions are weird and kind of arbitrary, man. They may
--       not align with where the type variable is actually detected.

-- | Verify that every type variable occurring in the types annotated on the
--   inferred definitions is bound by the scheme of an enclosing definition.
--
--   Starts with no bound variables; each definition adds the variables of its
--   own scheme while its body is checked. Fails with 'UnboundTVar', reported
--   at the position of the expression or pattern carrying the type.
checkTypeVarsBound :: Inferred.Defs -> Except TypeErr ()
checkTypeVarsBound ds = runReaderT (boundInDefs ds) Set.empty
  where
    boundInDefs :: Inferred.Defs -> Bound
    boundInDefs defs = forM_ (Inferred.flattenDefs defs) (secondM boundInDef)

    boundInDef (WithPos _ (Inferred.Forall tvs _ _, body)) =
        local (Set.union tvs) $ boundInExpr body

    boundInExpr (WithPos pos expr) = case expr of
        Inferred.Lit _ -> pure ()
        Inferred.Var (_, Inferred.TypedVar _ t) -> boundInType pos t
        Inferred.App f a rt -> boundInExpr f *> boundInExpr a *> boundInType pos rt
        Inferred.If p c a -> mapM_ boundInExpr [p, c, a]
        Inferred.Let ld b ->
            mapM_ (secondM boundInDef) (Inferred.defToVarDefs ld) *> boundInExpr b
        Inferred.FunMatch (cs, pt, bt) ->
            boundInCases cs *> mapM_ (boundInType pos) [pt, bt]
        Inferred.Ctor _ _ (_, instTs) ts -> mapM_ (boundInType pos) (instTs ++ ts)
        -- The type under a sizeof is not inspected.
        Inferred.Sizeof _ -> pure ()

    boundInType :: SrcPos -> Inferred.Type -> Bound
    boundInType pos ty = case ty of
        Inferred.TVar tv -> do
            unbound <- asks (Set.notMember tv)
            when unbound $ throwError (UnboundTVar pos)
        Inferred.TPrim _ -> pure ()
        Inferred.TConst (_, args) -> mapM_ (boundInType pos) args
        Inferred.TFun paramT retT -> boundInType pos paramT *> boundInType pos retT
        Inferred.TBox inner -> boundInType pos inner

    -- Each case is a (pattern, body) pair; the pattern is checked first.
    boundInCases cs = mapM_ (bimapM boundInPat boundInExpr) cs

    boundInPat (WithPos pos pat) = case pat of
        Inferred.PVar (Inferred.TypedVar _ t) -> boundInType pos t
        Inferred.PWild -> pure ()
        Inferred.PCon (Con _ _ ts) ps -> do
            mapM_ (boundInType pos) ts
            mapM_ boundInPat ps
        Inferred.PBox p -> boundInPat p

-- | Translate inferred definitions to checked ones, compiling every
--   pattern-matching function into a single-parameter function whose body
--   matches on that parameter with a decision tree.
--
--   Constructor expressions are turned into nested functions that take the
--   constructor's arguments one at a time. Errors from 'toDecisionTree' are
--   propagated as 'TypeErr'.
compileDecisionTrees :: MTypeDefs -> Inferred.Defs -> Except TypeErr Checked.Defs
compileDecisionTrees tdefs (Topo defs) = Topo <$> mapM compDef defs
  where
    compDef :: Inferred.Def -> Except TypeErr Checked.Def
    compDef (Inferred.VarDef (lhs, WithPos p rhs)) = do
        rhs' <- secondM compExpr rhs
        pure (Checked.VarDef (lhs, WithPos p rhs'))
    compDef (Inferred.RecDefs ds) =
        Checked.RecDefs <$> mapM (secondM (mapPosdM (secondM compFunMatch))) ds

    compFunMatch :: WithPos Inferred.FunMatch -> Except TypeErr (WithPos Checked.Fun)
    compFunMatch (WithPos pos (cs, tp, tb)) = do
        cs' <- mapM (secondM compExpr) cs
        body <- case runExceptT (toDecisionTree tdefs pos tp cs') of
            -- No decision tree at all: the body becomes 'Checked.Absurd'.
            Nothing -> pure (Checked.Absurd tb)
            Just result -> do
                dt <- liftEither result
                pure (Checked.Match scrutinee dt tb)
        pure (WithPos pos (param, (noPos body, tb)))
      where
        -- The generated function takes one parameter and matches on it.
        paramName = "#x"
        param = (paramName, tp)
        scrutinee = noPos (Checked.Var (NonVirt, Checked.TypedVar paramName tp))

    compExpr :: Inferred.Expr -> Except TypeErr Checked.Expr
    compExpr (WithPos pos ex) = withPos pos <$> case ex of
        Inferred.Lit c -> pure (Checked.Lit c)
        Inferred.Var (virt, Inferred.TypedVar (WithPos _ x) t) ->
            pure (Checked.Var (virt, Checked.TypedVar x t))
        Inferred.App f a tr -> Checked.App <$> compExpr f <*> compExpr a <*> pure tr
        Inferred.If p c a -> Checked.If <$> compExpr p <*> compExpr c <*> compExpr a
        Inferred.Let ld b -> Checked.Let <$> compDef ld <*> compExpr b
        Inferred.FunMatch fm -> Checked.Fun . unpos <$> compFunMatch (WithPos pos fm)
        Inferred.Ctor v span' inst ts -> pure (ctorAsFun pos v span' inst ts)
        Inferred.Sizeof t -> pure (Checked.Sizeof t)

    -- Wrap a constructor in one function per argument type, with parameters
    -- named x0, x1, ... and the construction itself as the innermost body.
    ctorAsFun pos v span' inst ts = snd (foldr wrap (resultT, construction) params)
      where
        params = zipWith (\n t -> ("x" ++ show (n :: Word), t)) [0 ..] ts
        args = [ noPos (Checked.Var (NonVirt, Checked.TypedVar x t)) | (x, t) <- params ]
        resultT = Inferred.TConst inst
        construction = Checked.Ction v span' inst args
        wrap (x, t) (bodyT, body) =
            (Inferred.TFun t bodyT, Checked.Fun ((x, t), (withPos pos body, bodyT)))