~jack/misc

ref: 792d1a6fb90d4bd33438b6e5e72a62737f8589d3 misc/stg/src/Language/STG/Eval.hs -rw-r--r-- 9.8 KiB
792d1a6f Jack Kelly stg: implement the eval/apply machine a month ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}

module Language.STG.Eval where

import qualified Data.Map as M
import Language.STG.AST
import Relude
import qualified Text.Show as S

-- | The complete state of the STG machine at one point in evaluation.
data StgState = StgState
  { -- | Argument stack: application pushes argument values here and
    -- entering a closure pops the ones it binds.
    sArgStack :: ArgStack,
    -- | Return stack of pending case continuations.
    sRetStack :: RetStack,
    -- | Update stack. Each frame holds the argument stack and return stack
    -- that were saved when an updatable closure was entered, together with
    -- that closure's heap address (see 'enterUpdatableClosureRule').
    sUpdStack :: [(ArgStack, RetStack, Addr)], -- frames: (saved args, saved returns, address to update)
    -- | The closure store.
    sHeap :: Heap,
    -- | Next unused heap address; allocation uses it and then bumps it
    -- with 'inc'.
    sTopOfHeap :: Addr,
    -- | Bindings for top-level definitions (built by 'initialState').
    sGlobalEnv :: Env,
    -- | Supply of fresh variable names.
    sGenSyms :: Gensyms,
    -- | The instruction currently being executed.
    sCode :: Code
  }
  deriving (Show, Eq)

-- | The machine's current instruction.
data Code
  = -- | Evaluate an expression in a local environment.
    Eval Expr Env
  | -- | Enter the closure stored at a heap address.
    Enter Addr
  | -- | Return a constructor together with its argument values.
    ReturnCon Con [Value]
  | -- | Return a primitive integer.
    ReturnInt Int
  deriving (Show, Eq)

-- | A heap address. New addresses are handed out sequentially via 'inc'.
newtype Addr = Addr Int deriving (Show, Eq, Ord)

-- | The address immediately following the given one; used to advance
-- 'sTopOfHeap' after an allocation.
inc :: Addr -> Addr
inc (Addr n) = Addr (succ n)

-- | A machine value: a pointer to a heap closure or a primitive integer.
data Value = VAddr Addr | VInt Int deriving (Show, Eq)

-- | Pending arguments. Application pushes them (leftmost argument on top)
-- and entering a closure pops the ones it binds.
type ArgStack = [Value]

-- | Case continuations: the alternatives to choose from, paired with the
-- environment that was current when the @case@ was evaluated.
type RetStack = [(Alts, Env)]

-- | Variable bindings.
type Env = Map Var Value

-- | The closure store, keyed by address.
type Heap = Map Addr HeapEnt

-- | A closure: a lambda form together with the captured values of its free
-- variables, positionally matching 'lfFreeVars'.
data HeapEnt = HeapEnt LambdaForm [Value] deriving (Show, Eq)

-- | A supply of fresh variable names. 'initialState' seeds it with an
-- infinite list (@_g0@, @_g1@, ...).
newtype Gensyms = Gensyms [Var]

-- | Prints a fixed placeholder instead of walking the (possibly infinite)
-- name list.
instance Show Gensyms where
  show _ = "Gensyms"

-- | Two supplies are equal when their next fresh name agrees (or both are
-- exhausted); only the first element is ever inspected, so comparing
-- infinite supplies terminates.
instance Eq Gensyms where
  Gensyms xs == Gensyms ys = take 1 xs == take 1 ys

-- | The starting machine state for a program.
--
-- Every top-level binding is allocated as a closure with no captured
-- values, at consecutive addresses from @Addr 0@ in ascending key order, and
-- is bound to its address in the global environment. Execution begins by
-- evaluating @main@ applied to no arguments, with all stacks empty.
initialState :: Prog -> StgState
initialState (Prog (Binds m)) =
  StgState
    { sCode = Eval (VarApp "main" []) [],
      sArgStack = [],
      sRetStack = [],
      sUpdStack = [],
      sHeap = M.fromList [(addr, HeapEnt lf []) | (addr, (_, lf)) <- numbered],
      sTopOfHeap = Addr (M.size m),
      sGlobalEnv = M.fromList [(v, VAddr addr) | (addr, (v, _)) <- numbered],
      sGenSyms = Gensyms $ Var . ("_g" <>) . show <$> ([0 ..] :: [Int])
    }
  where
    -- Each global binding paired with the address it is allocated at.
    numbered :: [(Addr, (Var, LambdaForm))]
    numbered = zip (map Addr [0 ..]) (M.toList m)

-- | Advance the machine by one transition.
--
-- Every rule is tried against the current state and exactly one must fire.
-- Zero matches (e.g. a finished or stuck machine) or several matches abort
-- with 'error'.
--
-- NOTE(review): 'returnConUpdateRule' is defined in this module but is not
-- in the list below, so a constructor returned with an empty return stack
-- and a pending update frame currently matches no rule.
step :: StgState -> StgState
step s = case mapMaybe ($ s) rules of
  [s'] -> s'
  ss ->
    error $
      "Should only have one matching state (got: " <> show (length ss) <> ")"
  where
    rules :: [StgState -> Maybe StgState]
    rules =
      [ applyRule,
        enterClosureRule,
        letRule,
        letrecRule,
        caseRule,
        conAppRule,
        returnConRule,
        returnConAnonDefaultRule,
        returnConDefaultRule,
        evalIntRule,
        evalIntVarRule,
        returnIntRule,
        returnIntAnonDefaultRule,
        returnIntDefaultRule,
        primOpRule,
        enterUpdatableClosureRule,
        enterPartialLambdaRule
      ]

-- | Resolve an atom to a machine value.
--
-- Literals become 'VInt'. A variable is looked up in the local environment
-- (first argument) before the global one (second argument); an unbound
-- variable is a fatal error.
val :: Env -> Env -> Atom -> Value
val _ _ (ALit (Literal k)) = VInt k
val locals globals (AVar v) =
  case M.lookup v locals <|> M.lookup v globals of
    Just w -> w
    Nothing -> error $ "missing binding " <> show v

-- | Application @f xs@: when @f@ resolves to a heap address, push the
-- argument values onto the argument stack (leftmost argument on top) and
-- enter @f@. Does not fire if @f@ is bound to an integer.
applyRule :: StgState -> Maybe StgState
applyRule s = case sCode s of
  Eval (VarApp f xs) p
    | VAddr a <- lookupVal (AVar f) ->
        Just s {sCode = Enter a, sArgStack = map lookupVal xs <> sArgStack s}
    where
      lookupVal = val p (sGlobalEnv s)
  _ -> Nothing

-- | Enter a non-updatable closure that has at least as many arguments on
-- the stack as it has argument variables.
--
-- The closure's free variables are bound to its captured values. Its
-- argument variables are bound to the topmost stack entries, which are
-- popped; any surplus arguments stay on the argument stack. Entering with
-- too few arguments is handled by 'enterPartialLambdaRule' instead.
enterClosureRule :: StgState -> Maybe StgState
enterClosureRule s = case sCode s of
  Enter a -> do
    HeapEnt LF {..} wsf <- M.lookup a $ sHeap s
    guard $ lfUpdateFlag == N
    -- The call must saturate the closure. (Previously this comparison was
    -- inverted, which rejected over-saturated calls and accepted
    -- under-saturated ones with missing argument bindings.)
    guard $ length (sArgStack s) >= length lfArgVars
    let (wsa, as') = splitAt (length lfArgVars) $ sArgStack s
        -- Later entries win in M.fromList, so arguments shadow free vars.
        p = M.fromList $ zip lfFreeVars wsf <> zip lfArgVars wsa
    Just $
      s
        { sCode = Eval lfExpr p,
          sArgStack = as'
        }
  _ -> Nothing

-- | Non-recursive @let@: allocate one fresh closure per binding and
-- evaluate the body in the extended environment.
--
-- Bindings get consecutive addresses starting at 'sTopOfHeap', in ascending
-- key order. Because the bindings are not recursive, each closure's free
-- variables are looked up in the environment /outside/ the @let@.
letRule :: StgState -> Maybe StgState
letRule s = case sCode s of
  Eval (Let (Binds bs) body) env ->
    let allocate (envAcc, heapAcc, addr) (v, lf) =
          ( M.insert v (VAddr addr) envAcc,
            M.insert addr (HeapEnt lf (map (env M.!) (lfFreeVars lf))) heapAcc,
            inc addr
          )
        (env', heap', top') =
          foldl' allocate (env, sHeap s, sTopOfHeap s) (M.toList bs)
     in Just s {sCode = Eval body env', sHeap = heap', sTopOfHeap = top'}
  _ -> Nothing

-- | Recursive @letrec@: allocate one fresh closure per binding and evaluate
-- the body in the extended environment.
--
-- Addresses are assigned exactly as in 'letRule'. Here, however, each
-- closure's free variables are looked up in the /extended/ environment, so
-- the bindings may refer to each other. The extended environment depends
-- only on the addresses, not on the heap, so building it first is safe.
letrecRule :: StgState -> Maybe StgState
letrecRule s = case sCode s of
  Eval (Letrec (Binds bs) body) env ->
    let start = sTopOfHeap s
        bindAddr (acc, addr) v = (M.insert v (VAddr addr) acc, inc addr)
        (env', _) = foldl' bindAddr (env, start) (M.keys bs)
        allocate (acc, addr) lf =
          ( M.insert addr (HeapEnt lf (map (env' M.!) (lfFreeVars lf))) acc,
            inc addr
          )
        (heap', top') = foldl' allocate (sHeap s, start) (M.elems bs)
     in Just s {sCode = Eval body env', sHeap = heap', sTopOfHeap = top'}
  _ -> Nothing

-- | @case e of alts@: push the alternatives, paired with the current
-- environment, onto the return stack, then evaluate the scrutinee in that
-- same environment.
caseRule :: StgState -> Maybe StgState
caseRule s
  | Eval (Case scrut alts) env <- sCode s =
      Just s {sCode = Eval scrut env, sRetStack = (alts, env) : sRetStack s}
  | otherwise = Nothing

-- | Constructor application: resolve the argument atoms and return the
-- constructor with the resulting values.
conAppRule :: StgState -> Maybe StgState
conAppRule s
  | Eval (ConApp con args) env <- sCode s =
      Just s {sCode = ReturnCon con (map (val env (sGlobalEnv s)) args)}
  | otherwise = Nothing

-- | Return a constructor to an algebraic @case@ that has an alternative for
-- it.
--
-- The top return-stack frame is popped. The matching alternative's body is
-- evaluated in the saved environment extended with the pattern variables
-- bound to the constructor's argument values. A default, if present, is
-- ignored when an explicit alternative matches.
returnConRule :: StgState -> Maybe StgState
returnConRule s = case sCode s of
  ReturnCon c ws -> do
    ((AAlts alts _, p), rs') <- uncons $ sRetStack s
    AAlt _ vs e <- find (\(AAlt c' _ _) -> c == c') alts
    -- M.union is left-biased: the pattern bindings go first so they shadow
    -- same-named variables in the saved environment. The previous
    -- `M.union p ...` let the stale outer binding win.
    let p' = M.union (M.fromList $ zip vs ws) p
    Just $
      s
        { sCode = Eval e p',
          sRetStack = rs'
        }
  _ -> Nothing

-- | Return a constructor to an algebraic @case@ that has no alternative for
-- it but has an unnamed default: pop the frame and evaluate the default in
-- the saved environment. The returned values are discarded.
returnConAnonDefaultRule :: StgState -> Maybe StgState
returnConAnonDefaultRule s = case (sCode s, sRetStack s) of
  (ReturnCon c _, (AAlts alts (Just (Default Nothing body)), env) : rest)
    | not (any (\(AAlt c' _ _) -> c == c') alts) ->
        Just s {sCode = Eval body env, sRetStack = rest}
  _ -> Nothing

-- | Return a constructor to an algebraic @case@ that has no alternative for
-- it but has a named default @v -> e@.
--
-- The returned value is re-boxed as a fresh non-updatable closure at
-- 'sTopOfHeap'. Its free variables are fresh names drawn from the gensym
-- supply, it captures the returned values, and its body re-applies the
-- constructor to those names. The frame is popped and @e@ is evaluated with
-- @v@ bound to the new closure's address.
returnConDefaultRule :: StgState -> Maybe StgState
returnConDefaultRule s = case (sCode s, sRetStack s) of
  (ReturnCon c ws, (AAlts alts (Just (Default (Just v) body)), env) : rest)
    | not (any (\(AAlt c' _ _) -> c == c') alts) ->
        let addr = sTopOfHeap s
            Gensyms supply = sGenSyms s
            (fresh, supply') = splitAt (length ws) supply
            boxed = HeapEnt (LF fresh N [] (ConApp c (map AVar fresh))) ws
         in Just
              s
                { sCode = Eval body (M.insert v (VAddr addr) env),
                  sRetStack = rest,
                  sHeap = M.insert addr boxed (sHeap s),
                  sTopOfHeap = inc addr,
                  sGenSyms = Gensyms supply'
                }
  _ -> Nothing

-- | An integer literal evaluates to itself; the environment is not needed.
evalIntRule :: StgState -> Maybe StgState
evalIntRule s
  | Eval (Lit (Literal k)) _ <- sCode s = Just s {sCode = ReturnInt k}
  | otherwise = Nothing

-- | A variable applied to no arguments that is bound to an integer in the
-- local environment returns that integer. Only the local environment is
-- consulted.
evalIntVarRule :: StgState -> Maybe StgState
evalIntVarRule s = case sCode s of
  Eval (VarApp v []) env
    | Just (VInt k) <- M.lookup v env -> Just s {sCode = ReturnInt k}
  _ -> Nothing

-- | Return an integer to a primitive @case@ that has a literal alternative
-- equal to it.
--
-- The frame is popped and the alternative's body is evaluated in the saved
-- environment.
returnIntRule :: StgState -> Maybe StgState
returnIntRule s = case sCode s of
  ReturnInt k -> do
    -- Match whether or not the case has a default: an explicit literal
    -- alternative takes priority. Previously only `PAlts alts Nothing`
    -- matched, so a matching literal in a case that also had a default
    -- fell through every rule. This mirrors 'returnConRule'.
    ((PAlts alts _, p), rs') <- uncons $ sRetStack s
    PAlt _ e <- find (\(PAlt (Literal k') _) -> k == k') alts
    Just $
      s
        { sCode = Eval e p,
          sRetStack = rs'
        }
  _ -> Nothing

-- | Return an integer to a primitive @case@ with no matching literal
-- alternative but an unnamed default: pop the frame and evaluate the
-- default in the saved environment.
returnIntAnonDefaultRule :: StgState -> Maybe StgState
returnIntAnonDefaultRule s = case (sCode s, sRetStack s) of
  (ReturnInt k, (PAlts alts (Just (Default Nothing body)), env) : rest)
    | not (any (\(PAlt (Literal k') _) -> k == k') alts) ->
        Just s {sCode = Eval body env, sRetStack = rest}
  _ -> Nothing

-- | Return an integer to a primitive @case@ with no matching literal
-- alternative but a named default @v -> e@: pop the frame and evaluate @e@
-- with @v@ bound directly to the integer.
returnIntDefaultRule :: StgState -> Maybe StgState
returnIntDefaultRule s = case (sCode s, sRetStack s) of
  (ReturnInt k, (PAlts alts (Just (Default (Just v) body)), env) : rest)
    | not (any (\(PAlt (Literal k') _) -> k == k') alts) ->
        Just s {sCode = Eval body (M.insert v (VInt k) env), sRetStack = rest}
  _ -> Nothing

-- | A binary integer primitive applied to two atoms returns the result.
--
-- Variable operands are looked up in the local environment only. The rule
-- does not fire if an operand is unbound or is not an integer.
primOpRule :: StgState -> Maybe StgState
primOpRule s = case sCode s of
  Eval (PrimApp prim [lhs, rhs]) env -> do
    let operand (AVar v) = M.lookup v env
        operand (ALit (Literal k)) = Just (VInt k)
    VInt k1 <- operand lhs
    VInt k2 <- operand rhs
    let result = case prim of
          IntAdd -> k1 + k2
          IntSub -> k1 - k2
          IntMul -> k1 * k2
    Just s {sCode = ReturnInt result}
  _ -> Nothing

-- | Enter an updatable closure that takes no arguments.
--
-- An update frame is pushed that records the current argument stack, the
-- current return stack and the closure's address. The body is then
-- evaluated with its free variables bound to the captured values and with
-- both stacks emptied.
enterUpdatableClosureRule :: StgState -> Maybe StgState
enterUpdatableClosureRule s = case sCode s of
  Enter a -> do
    HeapEnt LF {..} captured <- M.lookup a (sHeap s)
    guard (lfUpdateFlag == U && null lfArgVars)
    let frame = (sArgStack s, sRetStack s, a)
    Just
      s
        { sCode = Eval lfExpr (M.fromList (zip lfFreeVars captured)),
          sArgStack = [],
          sRetStack = [],
          sUpdStack = frame : sUpdStack s
        }
  _ -> Nothing

-- | Return a constructor when the argument and return stacks are both empty
-- and an update frame is pending.
--
-- The frame's target address is overwritten with a non-updatable closure
-- that rebuilds the constructor. Its free variables are fresh names drawn
-- from the gensym supply, and it captures the returned values. The frame is
-- popped, its saved stacks are restored, and the same constructor is
-- returned again to the restored context.
--
-- NOTE(review): this rule is not in 'step''s rule list; it has to be added
-- there to take effect.
returnConUpdateRule :: StgState -> Maybe StgState
returnConUpdateRule s = case sCode s of
  ReturnCon c ws -> do
    -- Only fire once the enclosing evaluation has no continuation of its
    -- own. Otherwise this rule would compete with 'returnConRule' and the
    -- default rules.
    guard . null $ sArgStack s
    guard . null $ sRetStack s
    ((asu, rsu, au), us') <- uncons $ sUpdStack s
    let (vs, sGenSyms') = splitAt (length ws) . coerce $ sGenSyms s
        h' =
          M.insert au (HeapEnt (LF vs N [] . ConApp c $ AVar <$> vs) ws) $
            sHeap s
    Just $
      s
        { sCode = ReturnCon c ws,
          sArgStack = asu,
          sRetStack = rsu,
          -- Previously the frame was never popped and the consumed names
          -- were never removed from the supply.
          sUpdStack = us',
          sHeap = h',
          sGenSyms = Gensyms sGenSyms'
        }
  _ -> Nothing

-- | Enter a non-updatable closure with fewer arguments than it expects,
-- when the return stack is empty and an update frame is pending.
--
-- The frame's target address is overwritten with an argument-free,
-- non-updatable closure. Its free variables are a fresh name @f@ followed by
-- the first @length args@ argument variables, it captures the entered
-- closure's address and the supplied arguments, and its body applies @f@ to
-- those variables. The frame is then popped, its saved arguments are
-- appended below the current ones, its return stack is restored, and the
-- closure is entered again.
enterPartialLambdaRule :: StgState -> Maybe StgState
enterPartialLambdaRule s = case sCode s of
  Enter a -> do
    guard . null $ sRetStack s
    ((asu, rsu, au), us') <- uncons $ sUpdStack s
    let args = sArgStack s
    HeapEnt LF {..} _ <- M.lookup a $ sHeap s
    guard $ lfUpdateFlag == N
    guard $ length args < length lfArgVars
    -- Draw one fresh name for the function. Binding in Maybe replaces the
    -- former partial `let Just (...) = ...` pattern.
    (f, sGenSyms') <- uncons . coerce $ sGenSyms s
    let xs1 = take (length args) lfArgVars
        hu =
          M.insert
            au
            ( HeapEnt
                ( LF (f : xs1) N []
                    . VarApp f
                    $ AVar <$> xs1
                )
                (VAddr a : args)
            )
            $ sHeap s
    Just $
      s
        { sCode = Enter a,
          sArgStack = args ++ asu,
          sRetStack = rsu,
          sUpdStack = us',
          sHeap = hu,
          sGenSyms = Gensyms sGenSyms'
        }
  _ -> Nothing