~jojo/Carth

ref: db87ea9bdce9470b546bb3098f5d79a9b1023af1 Carth/src/Mono.hs -rw-r--r-- 7.2 KiB
db87ea9bJoJo Disclaim WIP status in readme 1 year, 10 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
{-# LANGUAGE TemplateHaskell, LambdaCase, TupleSections
           , TypeSynonymInstances, FlexibleInstances, MultiParamTypeClasses
           , FlexibleContexts#-}

-- | Monomorphization
module Mono (monomorphize) where

import Control.Applicative (liftA2, liftA3)
import Control.Lens (makeLenses, views, use, uses, modifying)
import Control.Monad.Reader
import Control.Monad.State
import Data.Functor
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Maybe
import qualified Data.Set as Set
import Data.Set (Set)
import Data.Bitraversable

import Misc
import qualified DesugaredAst as An
import DesugaredAst (TVar(..), Scheme(..))
import MonoAst


-- | Read-only context for monomorphization, extended for sub-computations
-- via 'augment'.
data Env = Env
    { _envDefs :: Map String (Scheme, An.Expr)
      -- ^ Definitions bound by enclosing lets, with their type schemes and
      -- still-polymorphic bodies. Consulted by 'addDefInst' when a use
      -- demands a new instance.
    , _tvBinds :: Map TVar Type
      -- ^ Concrete types bound to the type variables of the definition (or
      -- type definition) currently being instantiated.
    }
makeLenses ''Env

-- | State accumulated while monomorphizing.
data Insts = Insts
    { _defInsts :: Map String (Map Type ([Type], Expr))
      -- ^ For each definition in scope: concrete instance type ->
      -- (types bound to the scheme's type variables, monomorphized body).
      -- Names without an entry here (e.g. function parameters) are never
      -- instantiated.
    , _tdefInsts :: Set TConst
      -- ^ Every type-definition instance encountered, to be emitted by
      -- 'instTypeDefs'.
    }
makeLenses ''Insts

-- | The monomorphization monad. Instances discovered so far are threaded as
-- state ('Insts'); the definitions and type-variable bindings in scope live
-- in a Reader environment ('Env') that is only extended locally.
type Mono = StateT Insts (Reader Env)

-- | Monomorphize a whole program.
--
-- Starting from the entry point @start@, each polymorphic definition gets
-- one instance per concrete type it is actually used at, and each type
-- definition one instance per concrete argument list encountered.
-- Extern types are converted with no type variables bound, so they must
-- not mention any.
monomorphize :: An.Program -> Program
monomorphize (An.Program defs tdefs externs) = evalMono $ do
    let monoExtern (name, t) = fmap ((,) name) (monotype t)
    externs' <- mapM monoExtern (Map.toList externs)
    -- Only definitions reachable from `start` end up instantiated.
    (defs', _) <- monoLet defs (An.Var (An.TypedVar "start" An.startType))
    -- Must come last: it reads every type instance registered above.
    tdefs' <- instTypeDefs tdefs
    pure (Program defs' tdefs' externs')

-- | Run a monomorphization computation starting from empty state and an
-- empty environment, keeping only its result.
evalMono :: Mono a -> a
evalMono = flip runReader initEnv . flip evalStateT initInsts

-- | No definition instances and no type-definition instances yet.
initInsts :: Insts
initInsts = Insts { _defInsts = Map.empty, _tdefInsts = Set.empty }

-- | No definitions in scope and no type variables bound.
initEnv :: Env
initEnv = Env Map.empty Map.empty

-- | Monomorphize an expression under the current type-variable bindings.
--
-- Variable occurrences are where instantiation is triggered: each one
-- records the concrete type it is used at via 'addDefInst'. Every other
-- construct is translated structurally.
mono :: An.Expr -> Mono Expr
mono expr = case expr of
    An.Lit lit -> pure (Lit lit)
    An.Var (An.TypedVar name polyT) -> do
        t <- monotype polyT
        addDefInst name t
        pure (Var (TypedVar name t))
    An.App fun arg retT -> App <$> mono fun <*> mono arg <*> monotype retT
    An.If cond conseq alt -> If <$> mono cond <*> mono conseq <*> mono alt
    An.Fun param body -> monoFun param body
    An.Let defs body -> uncurry Let <$> monoLet defs body
    An.Match scrut tree bodyT -> monoMatch scrut tree bodyT
    An.Ction ix span' inst args -> monoCtion ix span' inst args
    An.Box inner -> Box <$> mono inner
    An.Deref inner -> Deref <$> mono inner

-- | Monomorphize a lambda.
--
-- The parameter shadows any definition of the same name. That definition's
-- instance map is hidden while the body is processed, so uses of the
-- parameter are not taken as uses of the definition, and it is put back
-- afterwards.
monoFun :: (String, An.Type) -> (An.Expr, An.Type) -> Mono Expr
monoFun (param, paramT) (body, bodyT) = do
    shadowed <- uses defInsts (Map.lookup param)
    modifying defInsts (Map.delete param)
    fun <- liftA3
        (\paramT' body' bodyT' -> Fun (TypedVar param paramT') (body', bodyT'))
        (monotype paramT)
        (mono body)
        (monotype bodyT)
    -- Restore the shadowed definition's instances, if there were any.
    mapM_ (modifying defInsts . Map.insert param) shadowed
    pure fun

-- | Monomorphize a group of (possibly polymorphic) definitions together with
-- the expression they scope over.
--
-- Each definition starts with an empty instance map. Uses inside @body@
-- (via 'addDefInst') fill it in, so the returned 'Defs' contain exactly the
-- instances demanded, one entry per concrete type.
--
-- Same-named definitions from enclosing scopes are shadowed while the body
-- is processed and restored afterwards. Names introduced here are removed
-- from the instance state on exit. Otherwise they would linger after the
-- let, and a later same-named function parameter used at another type would
-- be mistaken for them ('addDefInst' would then fail to find it in
-- 'envDefs').
monoLet :: An.Defs -> An.Expr -> Mono (Defs, Expr)
monoLet ds body = do
    let ks = Map.keys ds
    parentInsts <- uses defInsts (lookups ks)
    let newEmptyInsts = fmap (const Map.empty) ds
    modifying defInsts (Map.union newEmptyInsts)
    body' <- augment envDefs ds (mono body)
    dsInsts <- uses defInsts (lookups ks)
    -- Leave the scope: drop our own entries, then bring back whatever they
    -- shadowed.
    modifying defInsts (Map.union (Map.fromList parentInsts) . deletes ks)
    let ds' = Map.fromList $ do
            (name, dInsts) <- dsInsts
            (t, (us, dbody)) <- Map.toList dInsts
            pure (TypedVar name t, (us, dbody))
    pure (ds', body')

-- | Monomorphize a pattern match: the scrutinee, its decision tree, and the
-- result type of the branches.
monoMatch :: An.Expr -> An.DecisionTree -> An.Type -> Mono Expr
monoMatch scrutinee tree resultT = do
    scrutinee' <- mono scrutinee
    tree' <- monoDecisionTree tree
    resultT' <- monotype resultT
    pure (Match scrutinee' tree' resultT')

-- | Monomorphize a decision tree.
--
-- Variables bound at a leaf shadow any definitions of the same names while
-- the leaf's body is processed, just as lambda parameters do. The shadowed
-- instance maps are restored afterwards.
monoDecisionTree :: An.DecisionTree -> Mono DecisionTree
monoDecisionTree = \case
    An.DSwitch obj cases fallback -> liftA3
        DSwitch
        (monoAccess obj)
        (mapM monoDecisionTree cases)
        (monoDecisionTree fallback)
    An.DLeaf (bindings, body) -> do
        let bindingList = Map.toList bindings
            names = map (varName . fst) bindingList
        shadowed <- uses defInsts (lookups names)
        modifying defInsts (deletes names)
        bindings' <- mapM (bimapM monoVar monoAccess) bindingList
        body' <- mono body
        modifying defInsts (Map.union (Map.fromList shadowed))
        pure (DLeaf (bindings', body'))
  where
    varName (An.TypedVar x _) = x
    monoVar (An.TypedVar x t) = TypedVar x <$> monotype t

-- | Monomorphize an access path (used by decision trees to reach parts of
-- the matched value). Only the 'As' step carries types; the rest is
-- structural.
monoAccess :: An.Access -> Mono Access
monoAccess access = case access of
    An.Obj -> pure Obj
    An.As inner span' ts -> do
        inner' <- monoAccess inner
        ts' <- mapM monotype ts
        pure (As inner' span' ts')
    An.Sel i span' inner -> Sel i span' <$> monoAccess inner
    An.ADeref inner -> ADeref <$> monoAccess inner

-- | Monomorphize a data-constructor application.
--
-- The constructed value's type is the type-definition instance
-- @(tdefName, tdefArgs')@. It is registered in 'tdefInsts', just as
-- 'monotype' does for a 'TConst' type. Without this, a type that only ever
-- appears as the result of a constructor (for instance a scrutinee built
-- inline, since 'Match' does not carry the scrutinee's type) might never get
-- its definition instantiated by 'instTypeDefs'.
monoCtion :: VariantIx -> Span -> An.TConst -> [An.Expr] -> Mono Expr
monoCtion i span' (tdefName, tdefArgs) as = do
    tdefArgs' <- mapM monotype tdefArgs
    let tdefInst = (tdefName, tdefArgs')
    modifying tdefInsts (Set.insert tdefInst)
    as' <- mapM mono as
    pure (Ction (i, span', tdefInst, as'))

-- | Record a use of @x@ at the concrete type @t1@, instantiating the
-- definition if this is its first use at that type.
--
-- Names without an entry in 'defInsts' are ignored. Otherwise:
--
-- * the definition's scheme and polymorphic body are fetched from
--   'envDefs' ('ice' if missing);
-- * the scheme's type variables are bound by matching its type against
--   @t1@ ('bindTvs');
-- * the body is monomorphized under those bindings.
--
-- The result is stored in @x@'s instance map under @t1@, paired with the
-- bound types in ascending order of their type variables.
addDefInst :: String -> Type -> Mono ()
addDefInst x t1 = do
    use defInsts <&> Map.lookup x >>= \case
        -- If x is not in insts, it's a function parameter (or a variable
        -- bound by a match pattern). Ignore.
        Nothing -> pure ()
        -- Only instantiate if x has no instance at t1 yet.
        Just xInsts -> when (not (Map.member t1 xInsts)) $ do
            (Forall _ t2, body) <- views
                envDefs
                (lookup' (ice (x ++ " not in defs")) x)
            -- mfix ties a knot: the entry for t1 is inserted before the body
            -- is monomorphized, referring lazily to that computation's result.
            _ <- mfix $ \body' -> do
                -- The instantiation must be in the environment when
                -- monomorphizing the body, or we may infinitely recurse.
                let boundTvs = bindTvs t2 t1
                    instTs = Map.elems boundTvs
                insertInst t1 (instTs, body')
                augment tvBinds boundTvs (mono body)
            pure ()
    -- Map.adjust: only updates x's existing instance map, never creates one.
    where insertInst t b = modifying defInsts (Map.adjust (Map.insert t b) x)

-- | Match a polymorphic type against a concrete instance of it, returning
-- the concrete type each type variable stands for.
--
-- Calls 'ice' if the two types do not have the same shape. If a type
-- variable occurs more than once, the leftmost occurrence's binding wins,
-- because map union is left-biased.
bindTvs :: An.Type -> Type -> Map TVar Type
bindTvs poly concrete = case (poly, concrete) of
    (An.TVar tv, t) -> Map.singleton tv t
    (An.TFun polyParam polyRet, TFun param ret) ->
        bindTvs polyParam param <> bindTvs polyRet ret
    (An.TBox polyInner, TBox inner) -> bindTvs polyInner inner
    (An.TPrim _, TPrim _) -> Map.empty
    (An.TConst (_, polyArgs), TConst (_, args)) ->
        mconcat (zipWith bindTvs polyArgs args)
    -- Everything below is a shape mismatch. The cases are listed per
    -- constructor rather than with a wildcard, which keeps the match
    -- visible to GHC's exhaustiveness checking.
    (An.TPrim _, _) -> mismatch
    (An.TFun _ _, _) -> mismatch
    (An.TBox _, _) -> mismatch
    (An.TConst _, _) -> mismatch
  where
    mismatch = ice $ "bindTvs: " ++ show poly ++ ", " ++ show concrete

-- | Convert a type to a monomorphic one, substituting type variables from
-- 'tvBinds' ('ice' on an unbound one).
--
-- Every type-definition instance met along the way is recorded in
-- 'tdefInsts' so that 'instTypeDefs' can emit it later.
monotype :: An.Type -> Mono Type
monotype ty = case ty of
    An.TVar tv ->
        let unbound = ice (show tv ++ " not in tvBinds")
        in views tvBinds (lookup' unbound tv)
    An.TPrim prim -> pure (TPrim prim)
    An.TFun dom cod -> TFun <$> monotype dom <*> monotype cod
    An.TBox inner -> TBox <$> monotype inner
    An.TConst (name, args) -> do
        inst <- fmap ((,) name) (mapM monotype args)
        modifying tdefInsts (Set.insert inst)
        pure (TConst inst)

-- | Instantiate every type-definition instance recorded so far, including
-- any that instantiating them in turn discovers.
instTypeDefs :: An.TypeDefs -> Mono TypeDefs
instTypeDefs tdefs = use tdefInsts >>= instTypeDefs' tdefs . Set.toList

-- | Worklist loop behind 'instTypeDefs'.
--
-- Instantiating one type definition can mention instances of other type
-- definitions, which 'monotype' adds to 'tdefInsts'. Those not seen before
-- are pushed onto the front of the worklist, so each instance is processed
-- exactly once.
instTypeDefs' :: An.TypeDefs -> [TConst] -> Mono TypeDefs
instTypeDefs' _ [] = pure []
instTypeDefs' tdefs (inst : pending) = do
    seenBefore <- use tdefInsts
    tdef' <- instTypeDef tdefs inst
    seenAfter <- use tdefInsts
    let discovered = Set.toList (seenAfter `Set.difference` seenBefore)
    (tdef' :) <$> instTypeDefs' tdefs (discovered ++ pending)

-- | Instantiate one type definition at concrete type arguments.
--
-- The definition's type parameters are bound to @ts@ while the member types
-- of its variants are monomorphized. Calls 'ice' if @x@ is not in @tdefs@.
instTypeDef :: An.TypeDefs -> TConst -> Mono (TConst, [VariantTypes])
instTypeDef tdefs inst@(x, ts) = do
    let (params, variants) = lookup' (ice "lookup' failed in instTypeDef") x tdefs
        bindings = Map.fromList (zip params ts)
    variants' <- augment tvBinds bindings (traverse (traverse monotype) variants)
    pure (inst, variants')

-- | Look up @k@ in @m@, falling back to @d@ when it is absent.
--
-- In this module the fallback is always an 'ice' call. It is only evaluated
-- when the key is missing.
lookup' :: Ord k => v -> k -> Map k v -> v
lookup' d k m = fromMaybe d (Map.lookup k m)

-- | Pair each key in @ks@ with its value in @m@. Keys absent from @m@ are
-- skipped, and the order of @ks@ is preserved.
lookups :: Ord k => [k] -> Map k v -> [(k, v)]
lookups ks m = mapMaybe withValue ks
  where withValue k = (,) k <$> Map.lookup k m

-- | Remove every key in @keys@ from the map. Keys that are not present are
-- ignored.
deletes :: (Foldable t, Ord k) => t k -> Map k v -> Map k v
deletes keys m = foldr Map.delete m keys