~jojo/Carth

ref: ae1d242d7d48292779dcbd953e5864bb4211e1ca Carth/src/Infer.hs -rw-r--r-- 18.1 KiB
ae1d242dJoJo Update stackage release & use default-extensions in cabal file 7 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
{-# LANGUAGE TemplateHaskell, DataKinds #-}

module Infer (inferTopDefs, checkType', checkType'') where

import Prelude hiding (span)
import Lens.Micro.Platform (makeLenses, over, view, mapped, to, Lens')
import Control.Applicative hiding (Const(..))
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Bifunctor
import Data.Graph (SCC(..), stronglyConnComp)
import Data.List hiding (span)
import qualified Data.Map as Map
import Data.Map (Map)
import Data.Maybe
import qualified Data.Set as Set
import Control.Arrow ((>>>))

import Misc
import SrcPos
import FreeVars
import Subst
import qualified Parsed
import Parsed (Id(..), IdCase(..), idstr, defLhs)
import Err
import Inferred hiding (Id)
import TypeAst hiding (TConst)


-- | The type that the context requires at some point in the program.
newtype ExpectedType = Expected Type
-- | The type that was actually inferred, together with the source position
--   that `solve` attaches to the error if unification with the expected type
--   fails.
data FoundType = Found SrcPos Type

-- | A pending equation between an expected and a found type. Constraints are
--   emitted by `unify` and resolved by `solve`.
type Constraint = (ExpectedType, FoundType)

-- | The environment carried in the Reader layer of `Infer`.
data Env = Env
    { _envTypeDefs :: TypeDefs
    -- Separate global defs and local defs, because `generalize` only has to look at
    -- local defs.
    , _envGlobDefs :: Map String Scheme
    , _envLocalDefs :: Map String Scheme
    -- | Maps a constructor to its variant index in the type definition it
    --   constructs, the signature/left-hand-side of the type definition, the
    --   types of its parameters, and the span (number of constructors) of the
    --   datatype
    , _envCtors :: Map String (VariantIx, (String, [TVar]), [Type], Span)
    }
-- Template Haskell: derives the lenses (`envTypeDefs`, `envGlobDefs`,
-- `envLocalDefs`, `envCtors`) used throughout this module.
makeLenses ''Env

-- | Supply of type variable names that have not been handed out yet. `fresh'`
--   takes names from the front of the list.
type FreshTVs = [String]

-- | The inference monad: a Writer layer accumulating `Constraint`s, a Reader
--   layer holding the `Env`, a State layer holding the fresh-name supply, and
--   an Except layer for `TypeErr`s.
type Infer a = WriterT [Constraint] (ReaderT Env (StateT FreshTVs (Except TypeErr))) a

------------------------------------------------------------------------------------------
-- Inference
------------------------------------------------------------------------------------------

-- | Infer the types of all top-level definitions.
--
--   Inference starts in an environment without local definitions whose global
--   definitions are the externs (as schemes with no quantified variables)
--   unioned with `builtinVirtuals`. The constraints left in the Writer layer are
--   discarded; only the inferred definitions are returned.
inferTopDefs :: TypeDefs -> Ctors -> Externs -> [Parsed.Def] -> Except TypeErr Defs
inferTopDefs tdefs ctors externs defs = evalStateT
    (runReaderT (fst <$> runWriterT (inferDefs envGlobDefs defs)) initEnv)
    tvNames
  where
    initEnv = Env
        { _envTypeDefs = tdefs
        , _envGlobDefs = Map.union externSchemes builtinVirtuals
        , _envLocalDefs = Map.empty
        , _envCtors = ctors
        }
    externSchemes = fmap (Forall Set.empty . fst) externs
    -- Infinite supply of names: first every letter followed by a number in
    -- 1..99, then every letter prepended to each name of the supply itself.
    tvNames = numbered ++ [ l : name | l <- letters, name <- tvNames ]
    numbered = [ l : show n | l <- letters, n <- [1 :: Word .. 99] ]
    letters = "abcdehjkpqrstuvxyz"

-- | Check a parsed type against the type definitions of the current
--   environment. See `checkType'`.
checkType :: SrcPos -> Parsed.Type -> Infer Type
checkType pos t = do
    tdefs <- view envTypeDefs
    checkType' tdefs pos t

-- | Check a parsed type against the given type definitions. The number of
--   parameters of a type is the length of the first component of its entry in
--   the definitions map.
checkType' :: MonadError TypeErr m => TypeDefs -> SrcPos -> Parsed.Type -> m Type
checkType' tdefs = checkType'' arityOf
    where arityOf name = length . fst <$> Map.lookup name tdefs

-- | Convert a parsed type to a checked type.
--
--   The first argument maps the name of a type to the number of parameters it
--   takes, or `Nothing` if no such type is defined. Throws `UndefType` for an
--   unknown type name and `TypeInstArityMismatch` when a type is applied to the
--   wrong number of arguments; both errors carry the given position.
checkType''
    :: MonadError TypeErr m => (String -> Maybe Int) -> SrcPos -> Parsed.Type -> m Type
checkType'' tdefsParams pos = go
  where
    go (Parsed.TVar v) = pure (TVar v)
    go (Parsed.TPrim p) = pure (TPrim p)
    go (Parsed.TConst tc) = TConst <$> checkTConst tc
    go (Parsed.TFun f a) = TFun <$> go f <*> go a
    go (Parsed.TBox t) = TBox <$> go t
    checkTConst (x, inst) =
        let foundN = length inst
        in  case tdefsParams x of
                Nothing -> throwError (UndefType pos x)
                Just expectedN
                    -- The arity is verified before the arguments themselves
                    -- are checked.
                    | expectedN == foundN -> (,) x <$> mapM go inst
                    | otherwise -> throwError
                        (TypeInstArityMismatch pos x expectedN foundN)

-- | Infer a group of definitions that may refer to each other.
--
--   After checking for duplicate names, the definitions are split into
--   components by `orderDefs` and inferred component by component. The
--   signatures of each inferred component are added, via `augment`, to the
--   environment field selected by the lens while the remaining components are
--   inferred.
inferDefs :: Lens' Env (Map String Scheme) -> [Parsed.Def] -> Infer Defs
inferDefs envDefs defs = do
    checkNoDuplicateDefs defs
    inferComponents (orderDefs defs)
  where
    inferComponents [] = pure (Topo [])
    inferComponents (scc : sccs) = do
        def <- inferComponent scc
        Topo rest <- augment envDefs (Map.fromList (defSigs def)) (inferComponents sccs)
        pure (Topo (def : rest))

-- | Throw `ConflictingVarDef` if two of the definitions bind the same name. The
--   error carries the position of the later of the two conflicting
--   definitions.
checkNoDuplicateDefs :: [Parsed.Def] -> Infer ()
checkNoDuplicateDefs = go Set.empty
  where
    go _ [] = pure ()
    go seen (d : ds) = case defLhs d of
        Id (WithPos p x)
            | Set.member x seen -> throwError (ConflictingVarDef p x)
            | otherwise -> go (Set.insert x seen) ds

-- Mutually recursive functions must be inferred together before any of them is
-- generalized, or unification will not work properly. To find out which
-- definitions belong together, we build a graph where each node is a
-- definition and each edge is a reference to another definition (one of its
-- free variables), and compute its strongly connected components (SCCs). An
-- acyclic component is a single definition that is not recursive; a cyclic
-- component is a directly recursive definition or a group of mutually
-- recursive ones. Types are inferred for a whole component before
-- generalizing.
orderDefs :: [Parsed.Def] -> [SCC Parsed.Def]
orderDefs defs =
    stronglyConnComp [ (d, defLhs d, Set.toList (freeVars d)) | d <- defs ]

-- | Infer one component of the dependency graph: an acyclic component as a
--   single non-recursive definition, a cyclic one as a recursive group.
inferComponent :: SCC Parsed.Def -> Infer Def
inferComponent scc = case scc of
    AcyclicSCC vert -> VarDef <$> inferNonrecDef vert
    CyclicSCC verts -> RecDefs <$> inferRecDefs verts

-- | Infer a single definition that does not refer to itself: infer the body,
--   solve the constraints it produced, and generalize the resulting type with
--   respect to the local definitions in scope.
inferNonrecDef :: Parsed.Def -> Infer VarDef
-- | Infer a group of recursive definitions together. All definitions are
--   inferred with monomorphic placeholder types for each other in scope, and
--   are only generalized once the constraints of the whole group are solved.
inferRecDefs :: [Parsed.Def] -> Infer RecDefs
(inferNonrecDef, inferRecDefs) = (inferNonrecDef', inferRecDefs')
  where
    -- A function definition is desugared to a variable definition whose body
    -- is a chain of single-case, single-parameter lambdas.
    inferNonrecDef' (Parsed.FunDef dpos lhs mayscm params body) =
        -- FIXME: Just wanted to get things working, but this isn't really better than
        --        doing the fold in the parser. Handle this such that we don't have to
        --        assign the definition position to the nested lambdas.
        inferNonrecDef' $ Parsed.VarDef dpos lhs mayscm $ foldr
            (\p b -> WithPos dpos (Parsed.FunMatch [(p, b)]))
            body
            params
    inferNonrecDef' (Parsed.VarDef dpos lhs mayscm body) = do
        t <- fresh
        -- `listen` hands us the constraints emitted while inferring the body,
        -- so that they can be solved right away.
        (body', cs) <- listen $ inferDef t lhs mayscm (getPos body) (infer body)
        -- `solve` runs in the innermost `Except` layer of `Infer`.
        sub <- lift $ lift $ lift $ solve cs
        env <- view envLocalDefs
        let scm = generalize (substEnv sub env) (subst sub t)
        let body'' = substExpr sub body'
        pure (idstr lhs, WithPos dpos (scm, body''))

    inferRecDefs' ds = do
        ts <- replicateM (length ds) fresh
        let (names, poss) = unzip $ flip map ds $ \case
                Parsed.FunDef p x _ _ _ -> (idstr x, p)
                Parsed.VarDef p x _ _ -> (idstr x, p)
        -- Within the group, every definition is visible under a scheme that
        -- quantifies over nothing.
        let dummyDefs = Map.fromList (zip names (map (Forall Set.empty) ts))
        (fs, cs) <- listen $ augment envLocalDefs dummyDefs $ zipWithM inferRecDef ts ds
        sub <- lift $ lift $ lift $ solve cs
        env <- view envLocalDefs
        let scms = map (generalize (substEnv sub env) . subst sub) ts
        let fs' = map (mapPosd (substFunMatch sub)) fs
        pure (zip names (zipWith3 (curry . WithPos) poss scms fs'))

    -- Infer one member of a recursive group against its placeholder type. Only
    -- definitions whose body is a function are accepted; anything else is
    -- reported as `RecursiveVarDef`.
    inferRecDef :: Type -> Parsed.Def -> Infer (WithPos FunMatch)
    inferRecDef t = \case
        Parsed.FunDef fpos lhs mayscm params body -> case unsnoc params of
            Just (initps, lastp) ->
                fmap (WithPos fpos) $ inferDef t lhs mayscm fpos $ inferFunMatch $ foldr
                    (\p cs -> [(p, WithPos fpos (Parsed.FunMatch cs))])
                    [(lastp, body)]
                    initps
            -- A function definition without parameters is just a variable
            -- definition. Treat it like one, as `inferNonrecDef'` does,
            -- instead of crashing on the empty parameter list.
            Nothing -> inferRecDef t (Parsed.VarDef fpos lhs mayscm body)
        Parsed.VarDef fpos lhs mayscm (WithPos _ (Parsed.FunMatch cs)) ->
            fmap (WithPos fpos) $ inferDef t lhs mayscm fpos (inferFunMatch cs)
        Parsed.VarDef _ (Id lhs) _ _ -> throwError (RecursiveVarDef lhs)

    -- Shared by both kinds of definitions: `t` is the type variable standing
    -- for the definition. It is first constrained by the user-provided
    -- signature, if `checkScheme` yields one, and then by the type inferred
    -- for the body.
    inferDef t lhs mayscm bodyPos inferBody = do
        checkScheme (idstr lhs) mayscm >>= \case
            Just (Forall _ scmt) -> unify (Expected scmt) (Found bodyPos t)
            Nothing -> pure ()
        (t', body') <- inferBody
        unify (Expected t) (Found bodyPos t')
        pure body'

-- | Validate the optional user-provided type signature of the definition with
--   the given name.
--
--   @main@ without a signature is given `mainType`; @main@ with a signature
--   that quantifies over anything or differs from `mainType` is rejected with
--   `WrongMainType`. Any other signature is accepted only if it is equal to
--   the scheme obtained by generalizing its checked type in the current local
--   environment, otherwise `InvalidUserTypeSig` is thrown.
checkScheme :: String -> Maybe Parsed.Scheme -> Infer (Maybe Scheme)
checkScheme name mayscm = case (name, mayscm) of
    ("main", Nothing) -> pure (Just (Forall Set.empty mainType))
    ("main", Just s@(Parsed.Forall pos vs t))
        | not (Set.null vs) || t /= mainType -> throwError (WrongMainType pos s)
    (_, Nothing) -> pure Nothing
    (_, Just (Parsed.Forall pos vs t)) -> do
        t' <- checkType pos t
        locals <- view envLocalDefs
        let declared = Forall vs t'
            generalized = generalize locals t'
        unless (declared == generalized)
               (throwError (InvalidUserTypeSig pos declared generalized))
        pure (Just declared)

-- | Infer the type of an expression, emitting unification constraints along
--   the way. Returns the type together with the typed expression, which keeps
--   the position of the parsed expression.
infer :: Parsed.Expr -> Infer (Type, Expr)
infer (WithPos pos e) = second (WithPos pos) <$> case e of
    Parsed.Lit l -> pure (litType l, Lit l)
    -- A variable named "_" is a hole.
    Parsed.Var (Id (WithPos p "_")) -> throwError (FoundHole p)
    Parsed.Var x@(Id x') -> do
        t <- lookupEnv x
        pure (t, Var (TypedVar x' t))
    Parsed.App f a -> do
        targ <- fresh
        tret <- fresh
        (tf, f') <- infer f
        unify (Expected (TFun targ tret)) (Found (getPos f) tf)
        (ta, a') <- infer a
        unify (Expected targ) (Found (getPos a) ta)
        pure (tret, App f' a' tret)
    Parsed.If p c a -> do
        (tp, p') <- infer p
        (tc, c') <- infer c
        (ta, a') <- infer a
        unify (Expected tBool) (Found (getPos p) tp)
        -- The alternative must have the type of the consequent.
        unify (Expected tc) (Found (getPos a) ta)
        pure (tc, If p' c' a')
    Parsed.Let1 def body -> inferLet1 pos def body
    Parsed.Let defs body ->
        -- FIXME: positions
        -- Desugared to nested single-definition lets, all at the position of
        -- the original expression.
        let (firstDef, restDefs) = fromJust (uncons defs)
            nest d b = WithPos pos (Parsed.Let1 d b)
        in  inferLet1 pos firstDef (foldr nest body restDefs)
    Parsed.LetRec defs b -> do
        Topo defs' <- inferDefs envLocalDefs defs
        -- Wrap the body in one `Let` per inferred definition, with the
        -- signatures of that definition in scope for what it encloses.
        let wrap def inferX = do
                (tx, x') <- withLocals (defSigs def) inferX
                pure (tx, WithPos pos (Let def x'))
        second unpos <$> foldr wrap (infer b) defs'
    Parsed.TypeAscr x t -> do
        (tx, WithPos _ x') <- infer x
        t' <- checkType pos t
        unify (Expected t') (Found (getPos x) tx)
        pure (t', x')
    Parsed.Match matchee cases -> inferMatch pos matchee cases
    Parsed.FunMatch cases -> second FunMatch <$> inferFunMatch cases
    Parsed.Ctor c -> inferExprConstructor c
    Parsed.Sizeof t -> do
        t' <- checkType pos t
        pure (TPrim TNatSize, Sizeof t')

-- | Infer a let-expression with a single binding. A definition is inferred as
--   a non-recursive definition whose signature is in scope for the body; a
--   deconstruction is inferred as a match with a single case.
inferLet1 :: SrcPos -> Parsed.DefLike -> Parsed.Expr -> Infer (Type, Expr')
inferLet1 _ (Parsed.Def def) body = do
    def' <- inferNonrecDef def
    (t, body') <- augment1 envLocalDefs (defSig def') (infer body)
    pure (t, Let (VarDef def') body')
inferLet1 pos (Parsed.Deconstr pat matchee) body =
    inferMatch pos matchee [(pat, body)]

-- | Infer a match-expression. The result is represented as the application of
--   a `FunMatch` built from the cases, placed at the given position, to the
--   matchee.
inferMatch :: SrcPos -> Parsed.Expr -> [(Parsed.Pat, Parsed.Expr)] -> Infer (Type, Expr')
inferMatch pos matchee cases = do
    (tmatchee, matchee') <- infer matchee
    (tbody, cases') <- inferCases (Expected tmatchee) cases
    let matcher = FunMatch (cases', tmatchee, tbody)
    pure (tbody, App (WithPos pos matcher) matchee' tbody)

-- | Infer a pattern-matching function. The parameter gets a fresh type that
--   all patterns are expected to have; the result is a function type from that
--   parameter type to the type of the case bodies.
inferFunMatch :: [(Parsed.Pat, Parsed.Expr)] -> Infer (Type, FunMatch)
inferFunMatch cases = do
    tparam <- fresh
    (tresult, cases') <- inferCases (Expected tparam) cases
    pure (TFun tparam tresult, (cases', tparam, tresult))

-- | Infer the cases of a match. Every pattern is unified with the type of the
--   matchee, and every body with one common, fresh type, which is returned as
--   the type of the whole match.
inferCases
    :: ExpectedType -- Type of matchee. Expected type of pattern.
    -> [(Parsed.Pat, Parsed.Expr)]
    -> Infer (Type, Cases)
inferCases tmatchee cases = do
    (tpats, tbodies, cases') <- unzip3 <$> mapM inferCase cases
    mapM_ (unify tmatchee) tpats
    tbody <- fresh
    mapM_ (unify (Expected tbody)) tbodies
    pure (tbody, cases')

-- | Infer a single case: the pattern first, then the body with the variables
--   bound by the pattern in scope as schemes that quantify over nothing.
--   Returns the found types of the pattern and the body, each tagged with its
--   position, and the typed case.
inferCase :: (Parsed.Pat, Parsed.Expr) -> Infer (FoundType, FoundType, (Pat, Expr))
inferCase (pat, body) = do
    (tpat, pat', patVars) <- inferPat pat
    let locals =
            [ (Parsed.idstr x, Forall Set.empty (TVar tv)) | (x, tv) <- Map.toList patVars ]
    (tbody, body') <- withLocals locals (infer body)
    pure (Found (getPos pat) tpat, Found (getPos body) tbody, (pat', body'))

-- | Infer a pattern. Yields the type of the pattern, the pattern in the `Pat`
--   format tagged with the position of the parsed pattern, and a map from
--   each variable bound by the pattern to the fresh type variable assigned to
--   it.
inferPat :: Parsed.Pat -> Infer (Type, Pat, Map (Id 'Small) TVar)
inferPat pat = do
    (t, p, vars) <- inferPat' pat
    pure (t, WithPos (getPos pat) p, vars)
  where
    inferPat' = \case
        Parsed.PConstruction pos c ps -> inferPatConstruction pos c ps
        Parsed.PInt _ n -> pure (TPrim TIntSize, intPCon n, Map.empty)
        Parsed.PStr _ s -> pure (typeStr, strPCon s, Map.empty)
        -- The wildcard matches anything and binds nothing.
        Parsed.PVar (Id (WithPos _ "_")) -> do
            t <- fresh
            pure (t, PWild, Map.empty)
        Parsed.PVar x@(Id x') -> do
            tv <- fresh'
            let t = TVar tv
            pure (t, PVar (TypedVar x' t), Map.singleton x tv)
        Parsed.PBox _ inner -> do
            (tinner, inner', vars) <- inferPat inner
            pure (TBox tinner, PBox inner', vars)
    -- A string pattern is a constructor pattern without arguments. Its span
    -- is an internal-compiler-error thunk, so it must never be forced.
    strPCon s = PCon (Con (VariantStr s) (ice "span of Con with VariantStr") []) []
    -- An integer pattern is a constructor pattern without arguments, whose
    -- variant index is the integer itself and whose span is 2^64.
    intPCon n = PCon (Con (VariantIx (fromIntegral n)) (2 ^ (64 :: Integer)) []) []

-- | Infer a constructor pattern.
--
--   Throws `CtorArityMismatch` if the number of argument patterns differs from
--   the number of parameters of the constructor. Otherwise the type definition
--   is instantiated with fresh type variables, the argument patterns are
--   inferred and checked for conflicting variable bindings, and each argument
--   is unified with the corresponding instantiated parameter type.
inferPatConstruction
    :: SrcPos -> Id 'Big -> [Parsed.Pat] -> Infer (Type, Pat', Map (Id 'Small) TVar)
inferPatConstruction pos c cArgs = do
    (variantIx, tdefLhs, cParams, cSpan) <- lookupEnvConstructor c
    let nParams = length cParams
        nArgs = length cArgs
    unless (nParams == nArgs) $ throwError (CtorArityMismatch pos (idstr c) nParams nArgs)
    (tdefInst, cParams') <- instantiateConstructorOfTypeDef tdefLhs cParams
    (cArgTs, cArgs', cArgsVars) <- unzip3 <$> mapM inferPat cArgs
    boundVars <- nonconflictingPatVarDefs cArgsVars
    sequence_ $ zipWith3
        (\paramT argT arg -> unify (Expected paramT) (Found (getPos arg) argT))
        cParams'
        cArgTs
        cArgs
    let con = Con { variant = VariantIx variantIx, span = cSpan, argTs = cArgTs }
    pure (TConst tdefInst, PCon con cArgs', boundVars)

-- | Merge the variable bindings of several sub-patterns into one map. Throws
--   `ConflictingPatVarDefs` if a variable is bound by more than one of them.
nonconflictingPatVarDefs :: [Map (Id 'Small) TVar] -> Infer (Map (Id 'Small) TVar)
nonconflictingPatVarDefs = foldM merge Map.empty
  where
    merge acc vars = case Map.keys (Map.intersection acc vars) of
        Id (WithPos pos v) : _ -> throwError (ConflictingPatVarDefs pos v)
        [] -> pure (Map.union acc vars)

-- | Infer a constructor used as an expression. Its type is a curried function
--   from the instantiated parameter types of the constructor to the
--   instantiated datatype.
inferExprConstructor :: Id 'Big -> Infer (Type, Expr')
inferExprConstructor c = do
    (variantIx, tdefLhs, cParams, cSpan) <- lookupEnvConstructor c
    (tdefInst, cParams') <- instantiateConstructorOfTypeDef tdefLhs cParams
    let ctorType = foldr TFun (TConst tdefInst) cParams'
        ctorExpr = Ctor variantIx cSpan tdefInst cParams'
    pure (ctorType, ctorExpr)

-- | Instantiate a type definition for one use of one of its constructors.
--   Every type parameter of the definition is replaced by a fresh type
--   variable, both in the resulting type constant and in the parameter types
--   of the constructor.
instantiateConstructorOfTypeDef :: (String, [TVar]) -> [Type] -> Infer (TConst, [Type])
instantiateConstructorOfTypeDef (tName, tParams) cParams = do
    tVars <- replicateM (length tParams) fresh
    let inst = Map.fromList (zip tParams tVars)
    pure ((tName, tVars), map (subst inst) cParams)

-- | Look up a constructor in the environment. Throws `UndefCtor` if there is
--   no constructor with that name.
lookupEnvConstructor :: Id 'Big -> Infer (VariantIx, (String, [TVar]), [Type], Span)
lookupEnvConstructor (Id (WithPos pos cx)) = do
    ctors <- view envCtors
    case Map.lookup cx ctors of
        Just ctor -> pure ctor
        Nothing -> throwError (UndefCtor pos cx)

-- | The type of a literal.
litType :: Const -> Type
litType (Int _) = TPrim TIntSize
litType (F64 _) = TPrim TF64
litType (Str _) = typeStr

-- | The type of string literals and string patterns: the type constant named
--   "Str", applied to no arguments.
typeStr :: Type
typeStr = TConst ("Str", [])

-- | Look up a variable and instantiate its scheme. Local definitions take
--   precedence over global ones. Throws `UndefVar` if the variable is bound
--   in neither.
lookupEnv :: Id 'Small -> Infer Type
lookupEnv (Id (WithPos pos x)) = do
    globals <- view envGlobDefs
    locals <- view envLocalDefs
    maybe (throwError (UndefVar pos x))
          instantiate
          (Map.lookup x locals <|> Map.lookup x globals)

-- | Run an inference action with the given definitions passed, via `augment`,
--   to the local definitions of the environment.
withLocals :: [(String, Scheme)] -> Infer a -> Infer a
withLocals defs = augment envLocalDefs (Map.fromList defs)

-- | Instantiate a scheme: replace each quantified type variable in its body
--   with a fresh type variable.
instantiate :: Scheme -> Infer Type
instantiate (Forall params t) = do
    inst <- mapM (\param -> (,) param <$> fresh) (Set.toList params)
    pure (subst (Map.fromList inst) t)

-- | Generalize a type to a scheme that quantifies over every type variable
--   that is free in the type but not free in any scheme of the given
--   environment.
generalize :: Map String Scheme -> Type -> Scheme
generalize env t = Forall (ftv t `Set.difference` envFtvs) t
  where
    envFtvs = Set.unions (map freeInScheme (Map.elems env))
    -- The variables a scheme quantifies over are not free in it.
    freeInScheme (Forall bound body) = ftv body `Set.difference` bound

-- | Apply a substitution to the body of every scheme in an environment.
substEnv :: Subst -> Map String Scheme -> Map String Scheme
substEnv s = fmap (over scmBody (subst s))

-- | A fresh type variable, as a type.
fresh :: Infer Type
fresh = TVar <$> fresh'

-- | A fresh implicit type variable, named by the next name of the supply in
--   the State layer. That name is removed from the supply.
fresh' :: Infer TVar
fresh' = do
    name <- gets head
    modify tail
    pure (TVImplicit name)

-- | Record that the found type must equal the expected type. Nothing is
--   unified here: the constraint is only written to the Writer layer, to be
--   resolved by `solve` later.
unify :: ExpectedType -> FoundType -> Infer ()
unify expected found = tell [constraint] where constraint = (expected, found)

------------------------------------------------------------------------------------------
-- Constraint solver
------------------------------------------------------------------------------------------

-- | The ways in which `unifies` can fail. `UInfType` is thrown when the type
--   variable occurs in the type it is to be unified with, `UFailed` when the
--   two types cannot be unified.
data UnifyErr = UInfType TVar Type | UFailed Type Type

-- | Solve a list of constraints, in order, producing the substitution that
--   satisfies all of them. If a constraint cannot be satisfied, the
--   unification error is converted to a `TypeErr` that carries the position
--   and the two types of that constraint.
solve :: [Constraint] -> Except TypeErr Subst
solve = go Map.empty
  where
    go :: Subst -> [Constraint] -> Except TypeErr Subst
    go acc [] = pure acc
    go acc ((Expected et, Found pos ft) : rest) = do
        s <- withExcept (toTypeErr pos et ft) (unifies et ft)
        -- What was learned from this constraint is applied to the remaining
        -- ones before they are solved.
        go (composeSubsts s acc) (map (applyTo s) rest)

    applyTo s (Expected t1, Found pos t2) =
        (Expected (subst s t1), Found pos (subst s t2))

-- | Convert a unification error to a type error, adding the position and the
--   two types of the constraint that was being solved when it occurred.
toTypeErr :: SrcPos -> Type -> Type -> UnifyErr -> TypeErr
toTypeErr pos t1 t2 (UInfType a t) = InfType pos t1 t2 a t
toTypeErr pos t1 t2 (UFailed t'1 t'2) = UnificationFailed pos t1 t2 t'1 t'2

-- | Unify two types, producing a substitution under which they are equal.
--
--   Throws `UInfType` if a type variable would have to be bound to a type that
--   contains it, and `UFailed` if the types cannot be unified. The order of
--   the cases below matters.
unifies :: Type -> Type -> Except UnifyErr Subst
unifies t u = case (t, u) of
    (TPrim a, TPrim b) | a == b -> pure Map.empty
    (TConst (c0, ts0), TConst (c1, ts1))
        | c0 == c1 -> if length ts0 == length ts1
            then unifiesMany ts0 ts1
            else ice "lengths of TConst params differ in unify"
    (TVar a, TVar b) | a == b -> pure Map.empty
    -- Occurs check.
    (TVar a, _) | occursIn a u -> throwError (UInfType a u)
    -- Do not allow "override" of explicit (user given) type variables: an
    -- explicit variable is never bound. Against an implicit variable it is
    -- the implicit one that gets bound; against anything else unification
    -- fails.
    (TVar (TVExplicit _), TVar (TVImplicit _)) -> unifies u t
    (TVar (TVExplicit _), _) -> throwError (UFailed t u)
    (TVar a, _) -> pure (Map.singleton a u)
    -- A variable on the right is handled by the cases above, with the
    -- arguments swapped.
    (_, TVar _) -> unifies u t
    (TFun t1 t2, TFun u1 u2) -> unifiesMany [t1, t2] [u1, u2]
    (TBox t', TBox u') -> unifies t' u'
    _ -> throwError (UFailed t u)

-- | Unify two lists of types pairwise, from left to right. The substitution
--   accumulated so far is applied to each pair before it is unified, and the
--   result is composed onto the accumulated substitution.
unifiesMany :: [Type] -> [Type] -> Except UnifyErr Subst
unifiesMany ts us = foldM step Map.empty (zip ts us)
  where
    step acc (t, u) = do
        s <- unifies (subst acc t) (subst acc u)
        pure (composeSubsts s acc)

-- | Whether the type variable is among the free type variables of the type.
occursIn :: TVar -> Type -> Bool
occursIn a = Set.member a . ftv