~jojo/Carth

ref: 89feacb8828cbb1a3030454a6ecc25f7e21a23fa Carth/src/Check.hs -rw-r--r-- 9.4 KiB
89feacb8JoJo Move QuickCheck dep from lib to test in package.yaml 2 years ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
{-# LANGUAGE LambdaCase, DataKinds, TupleSections, FlexibleContexts #-}

module Check (typecheck) where

import Prelude hiding (span)
import Control.Monad.Except
import Control.Monad.Reader
import Data.Bifunctor
import Data.Bitraversable
import Data.Foldable
import Control.Applicative
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import qualified Data.Set as Set
import Data.Set (Set)
import Data.Maybe

import Misc
import SrcPos
import Subst
import qualified Ast
import Ast (Id(..), IdCase(..), idstr, Type(..), TVar(..), TPrim(..))
import TypeErr
import AnnotAst (VariantIx)
import qualified AnnotAst as An
import Match
import Infer
import qualified DesugaredAst as Des


-- | Run the whole type-checking pipeline on a parsed program.
--
--   The stages run in this order:
--
--   1. Validate the type definitions.
--   2. Infer types for the top-level definitions and apply the resulting
--      substitutions.
--   3. Verify that every type variable is bound.
--   4. Compile pattern matches to decision trees while desugaring.
--   5. Require that a definition named @start@ exists.
--
--   The first failing stage aborts the pipeline with its 'TypeErr'.
typecheck :: Ast.Program -> Either TypeErr Des.Program
typecheck (Ast.Program defs tdefs externs) = runExcept $ do
    (checkedTdefs, ctors) <- checkTypeDefs tdefs
    (externs', inferred, substs) <- inferTopDefs ctors externs defs
    let substituted = substTopDefs substs inferred
    checkTypeVarsBound substituted
    desugared <- compileDecisionTreesAndDesugar
        (fmap (map fst . snd) checkedTdefs)
        substituted
    assertStart desugared
    pure (Des.Program desugared (fmap (second (map snd)) checkedTdefs) externs')
  where
    -- The program must provide an entry point called "start".
    assertStart ds =
        unless (Map.member "start" ds) (throwError StartNotDefined)

-- | Validate the program's type definitions, together with the builtin ones.
--
--   The following properties are checked:
--
--   * No two type definitions, and no two constructors, share a name
--     (see 'checkTypeDefsNoConflicting').
--   * Every type constructor mentioned in a constructor field is defined.
--     This includes type constructors inside the arguments of a type
--     application. Each one must also be applied to as many type arguments
--     as its definition declares.
--   * No type definition contains itself without indirection.
--
--   Every definition passes the second check before any of them is
--   subjected to the third. The recursion check looks up each type name it
--   reaches with the partial 'Map.!', so it must only ever see names that
--   are known to be defined.
--
--   Returns the type definitions with constructor positions dropped, and
--   the constructor table.
checkTypeDefs :: [Ast.TypeDef] -> Except TypeErr (An.TypeDefs, An.Ctors)
checkTypeDefs tdefs = do
    (tdefs', ctors) <- checkTypeDefsNoConflicting tdefs
    let tdefs'' = fmap (second (map snd)) tdefs'
        tdefsList = Map.toList tdefs'
    -- Pass 1: every referenced type exists and is applied with the right
    -- arity.
    forM_ tdefsList (checkTConstsDefs tdefs'')
    -- Pass 2: no type contains itself without indirection. Must run after
    -- pass 1.
    forM_ tdefsList (assertNoRec tdefs')
    pure (tdefs'', ctors)
  where
    -- | Check that constructors don't refer to undefined types and that
    --   TConsts are of correct arity.
    checkTConstsDefs tds (_, (_, cs)) = forM_ cs (checkTConstsCtor tds)
    checkTConstsCtor tds (cpos, (_, ts)) = forM_ ts (checkType tds cpos)
    checkType tds cpos = \case
        TVar _ -> pure ()
        TPrim _ -> pure ()
        -- The type arguments must be checked as well. Otherwise e.g. the
        -- first argument of @Pair Undefined Int@ would slip through.
        TConst tc@(_, inst) ->
            checkTConst tds cpos tc *> forM_ inst (checkType tds cpos)
        TFun f a -> checkType tds cpos f *> checkType tds cpos a
        TBox t -> checkType tds cpos t
    checkTConst tds cpos (x, inst) = case Map.lookup x tds of
        Just (tvs, _) -> do
            let (expectedN, foundN) = (length tvs, length inst)
            unless (expectedN == foundN) $ throwError
                (TypeInstArityMismatch cpos x expectedN foundN)
        Nothing -> throwError (UndefType cpos x)
    -- | Check that type definition @x@ does not contain itself without
    --   indirection.
    --
    --   The fields of every type application reached are expanded, with the
    --   type's parameters substituted by the applied arguments. 'TFun' and
    --   'TBox' count as indirection and are not entered.
    assertNoRec tds (x, (_, cs)) = assertNoRecCtors tds x [] Map.empty cs
    assertNoRecCtors tds x seen s =
        mapM_ $ \(cpos, (_, ts)) ->
            forM_ ts (assertNoRecType tds x seen cpos . subst s)
    -- @seen@ holds the type applications currently being expanded on this
    -- path. Reaching one of them again means that some /other/ type is
    -- recursive. That type's error is reported when it is itself checked as
    -- @x@, so stop here instead of expanding it forever.
    --
    -- NOTE(review): if a recursive occurrence has growing arguments (e.g. a
    -- field of type @T (Box a)@ in the definition of @T a@), it never repeats
    -- exactly. Checking some other type that contains @T@ may therefore still
    -- not terminate.
    assertNoRecType tds x seen cpos = \case
        TVar _ -> pure ()
        TPrim _ -> pure ()
        TConst tc@(y, ts) -> do
            when (x == y) $ throwError (RecTypeDef x cpos)
            unless (tc `elem` seen) $ do
                -- Safe: pass 1 guarantees that y is defined.
                let (tvs, cs) = tds Map.! y
                    substs = Map.fromList (zip tvs ts)
                assertNoRecCtors tds x (tc : seen) substs cs
        TFun _ _ -> pure ()
        TBox _ -> pure ()

-- | Check that there are no conflicting type names or constructor names.
--
--   Starting from the builtin types and constructors, each user type
--   definition is added in turn:
--
--   * A definition whose type name is already taken fails with
--     'ConflictingTypeDef'.
--   * A definition introducing a constructor whose name is already taken
--     fails with 'ConflictingCtorDef'.
--
--   Duplicate constructors within a single definition are rejected by
--   'checkTypeDef'.
checkTypeDefsNoConflicting
    :: [Ast.TypeDef]
    -> Except
           TypeErr
           ( Map String ([TVar], [(SrcPos, (String, [Type]))])
           , Map String (VariantIx, (String, [TVar]), [Type], Span)
           )
checkTypeDefsNoConflicting =
    foldM addTypeDef (builtinDataTypes, builtinConstructors)
  where
    addTypeDef (typesAcc, ctorsAcc) td@(Ast.TypeDef name _ _) = do
        when (idstr name `Map.member` typesAcc)
            $ throwError (ConflictingTypeDef name)
        ((tx, tdef), newCtors) <- checkTypeDef td
        -- Constructors of this definition that were already introduced
        -- earlier.
        case Map.elems (Map.intersection newCtors ctorsAcc) of
            (cId, _) : _ -> throwError (ConflictingCtorDef cId)
            [] -> pure
                ( Map.insert tx tdef typesAcc
                , fmap snd newCtors `Map.union` ctorsAcc
                )

-- | Check a single type definition in isolation.
--
--   Returns a pair of two things:
--
--   * The definition itself: its name, its explicit type parameters, and
--     its constructors, each constructor tagged with its source position.
--   * A map from constructor name to that constructor's identifier, variant
--     index, owning type (name and parameters), field types, and the number
--     of constructors in the definition.
--
--   Throws 'ConflictingCtorDef' if a constructor name occurs twice.
checkTypeDef
    :: Ast.TypeDef
    -> Except
           TypeErr
           ( (String, ([TVar], [(SrcPos, (String, [Type]))]))
           , Map
                 String
                 (Id 'Big, (VariantIx, (String, [TVar]), [Type], Span))
           )
checkTypeDef (Ast.TypeDef x' ps (Ast.ConstructorDefs cs)) =
    (\ctorMap -> ((name, (params, located)), ctorMap))
        <$> foldM addCtor Map.empty (zip [0 ..] cs)
  where
    name = idstr x'
    params = map TVExplicit ps
    variantCount = fromIntegral (length cs)
    located = map (\(Id (WithPos p y), ts) -> (p, (y, ts))) cs
    -- Insert one constructor, rejecting a name already seen in this
    -- definition.
    addCtor acc (i, (cx, cps))
        | idstr cx `Map.member` acc = throwError (ConflictingCtorDef cx)
        | otherwise = pure
            (Map.insert (idstr cx) (cx, (i, (name, params), cps, variantCount)) acc)

-- | The builtin type definitions, in the same shape that
--   'checkTypeDefsNoConflicting' produces for user definitions. Builtins
--   have no source location, so each constructor is tagged with 'dummyPos'.
builtinDataTypes :: Map String ([TVar], [(SrcPos, (String, [Type]))])
builtinDataTypes = Map.fromList
    [ (x, (ps, [ (dummyPos, c) | c <- cs ]))
    | (x, ps, cs) <- builtinDataTypes'
    ]

-- | Constructor table for all builtin types. Tables are combined with the
--   left-biased 'Map' union, so earlier entries of 'builtinDataTypes'' win
--   on duplicate constructor names.
builtinConstructors :: Map String (VariantIx, (String, [TVar]), [Type], Span)
builtinConstructors = foldMap builtinConstructors' builtinDataTypes'

-- | Constructor table for one builtin type definition.
--
--   Each constructor maps to:
--
--   * its variant index, i.e. its position in the definition;
--   * the owning type's name and parameters;
--   * its field types;
--   * the number of constructors in the definition.
builtinConstructors'
    :: (String, [TVar], [(String, [Type])])
    -> Map String (VariantIx, (String, [TVar]), [Type], Span)
builtinConstructors' (x, ps, cs) = Map.fromList
    [ (cx, (i, (x, ps), cps, variantCount))
    | (i, (cx, cps)) <- zip [0 ..] cs
    ]
  where
    variantCount = fromIntegral (length cs)

-- | Raw builtin type definitions. Each entry is
--   (type name, type parameters, [(constructor name, field types)]).
builtinDataTypes' :: [(String, [TVar], [(String, [Type])])]
builtinDataTypes' = [array, str, pair]
  where
    a = TVImplicit 0
    b = TVImplicit 1
    -- Array a: one constructor with a boxed `a` field and a TNat field.
    array = ("Array", [a], [("Array", [TBox (TVar a), TPrim TNat])])
    -- Str: one constructor wrapping an Array of TNat8.
    str = ("Str", [], [("Str", [TConst ("Array", [TPrim TNat8])])])
    -- Pair a b: one constructor holding an `a` and a `b`.
    pair = ("Pair", [a, b], [("Pair", [TVar a, TVar b])])

-- | Computation used by 'checkTypeVarsBound'. It reads the set of type
--   variables currently in scope and may fail with a 'TypeErr'.
type Bound = ReaderT (Set.Set TVar) (Except TypeErr) ()

-- | Check that every type variable occurring in the definitions is bound by
--   an enclosing 'An.Forall'.
--
--   An unbound variable raises 'UnboundTVar'. The reported position is that
--   of the expression or pattern being inspected.
--
-- TODO: Many of these positions are weird and kind of arbitrary, man. They may
--       not align with where the type variable is actually detected.
checkTypeVarsBound :: An.Defs -> Except TypeErr ()
checkTypeVarsBound ds = runReaderT (boundInDefs ds) Set.empty
  where
    boundInDefs :: An.Defs -> Bound
    boundInDefs = traverse_ boundInDef
    -- The scheme's quantified variables are in scope inside the body.
    boundInDef (An.Forall tvs _, body) =
        local (Set.union tvs) (boundInExpr body)
    boundInExpr (WithPos pos expr) = case expr of
        An.Lit _ -> pure ()
        An.Var (An.TypedVar _ t) -> boundInType pos t
        An.App f a rt ->
            boundInExpr f *> boundInExpr a *> boundInType pos rt
        An.If p c a -> traverse_ boundInExpr [p, c, a]
        An.Let lds b -> boundInDefs lds *> boundInExpr b
        An.FunMatch cs pt bt ->
            boundInCases cs *> traverse_ (boundInType pos) [pt, bt]
        An.Ctor _ _ (_, instTs) ts ->
            traverse_ (boundInType pos) (instTs ++ ts)
        An.Box x -> boundInExpr x
        An.Deref x -> boundInExpr x
    boundInType :: SrcPos -> An.Type -> Bound
    boundInType pos = \case
        TVar tv -> do
            inScope <- asks (Set.member tv)
            unless inScope (throwError (UnboundTVar pos))
        TPrim _ -> pure ()
        TConst (_, ts) -> traverse_ (boundInType pos) ts
        TFun ft at -> boundInType pos ft *> boundInType pos at
        TBox t -> boundInType pos t
    boundInCases cs =
        forM_ cs $ \(pat, body) -> boundInPat pat *> boundInExpr body
    boundInPat (WithPos pos pat) = case pat of
        An.PVar (An.TypedVar _ t) -> boundInType pos t
        An.PWild -> pure ()
        An.PCon con ps -> boundInCon pos con *> traverse_ boundInPat ps
        An.PBox p -> boundInPat p
    boundInCon pos (Con _ _ ts) = traverse_ (boundInType pos) ts

-- | Translate the annotated AST into the desugared AST.
--
--   Two constructs are rewritten; everything else is translated node by
--   node.
--
--   * 'An.FunMatch' is compiled with 'toDecisionTree'. The result is a
--     one-parameter 'Des.Fun' whose parameter @"#x"@ is scrutinised by a
--     'Des.Match'. When running 'toDecisionTree' yields 'Nothing', the
--     function is replaced by 'Des.Absurd' instead.
--     NOTE(review): presumably 'Nothing' means a match over an uninhabited
--     type -- confirm in Match.
--   * 'An.Ctor' with @n@ fields becomes @n@ nested curried 'Des.Fun's. They
--     bind @"#x0"@ .. @"#x(n-1)"@ around a 'Des.Ction' applied to those
--     variables.
compileDecisionTreesAndDesugar
    :: MTypeDefs -> An.Defs -> Except TypeErr Des.Defs
compileDecisionTreesAndDesugar tdefs = compDefs
  where
    compDefs = traverse (\(scm, body) -> (scm, ) <$> compExpr body)
    compExpr :: An.Expr -> Except TypeErr Des.Expr
    compExpr (WithPos pos expr) = case expr of
        An.Lit c -> pure (Des.Lit c)
        An.Var (An.TypedVar (WithPos _ x) t) ->
            pure (Des.Var (Des.TypedVar x t))
        An.App f a tr -> Des.App <$> compExpr f <*> compExpr a <*> pure tr
        An.If p c a -> Des.If <$> compExpr p <*> compExpr c <*> compExpr a
        An.Let lds b -> Des.Let <$> compDefs lds <*> compExpr b
        An.FunMatch cs tp tb -> do
            cs' <- traverse (secondM compExpr) cs
            maybe
                (pure (Des.Absurd tb))
                (fmap (matchFun tp tb) . liftEither)
                (runExceptT (toDecisionTree tdefs pos tp cs'))
        An.Ctor v span' inst ts -> pure (ctorFun v span' inst ts)
        An.Box x -> Des.Box <$> compExpr x
        An.Deref x -> Des.Deref <$> compExpr x
    -- Wrap a decision tree in a lambda whose parameter it scrutinises.
    matchFun tp tb dt =
        let p = "#x"
            scrutinee = Des.Var (Des.TypedVar p tp)
        in Des.Fun (p, tp) (Des.Match scrutinee dt tb, tb)
    -- Eta-expand a constructor into curried lambdas, one per field.
    ctorFun v span' inst ts =
        let params = zipWith (\n t -> ("#x" ++ show n, t)) [0 :: Word ..] ts
            args = map (Des.Var . uncurry Des.TypedVar) params
            wrap (p, pt) (body, bt) = (Des.Fun (p, pt) (body, bt), TFun pt bt)
        in fst (foldr wrap (Des.Ction v span' inst args, TConst inst) params)