~cypheon/Idris2

ref: 18f9c5484d9664e5355087d2253fcb20b60fc0ea Idris2/src/Compiler/LambdaLift.idr -rw-r--r-- 18.7 KiB
18f9c548 — Johann Rudloff [ refactor ] Pass through `Used vars` instead of creating and merging 3 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
module Compiler.LambdaLift

import Core.CompileExpr
import Core.Context
import Core.Core
import Core.TT

import Data.List
import Data.Vect
import Data.Maybe

%default covering

mutual
  -- Lambda-lifted expressions, indexed by the names of the local variables in
  -- scope (the most recently bound variable is at the head of the list).
  --
  -- Several constructors carry a `lazy : Maybe LazyReason` field. It records
  -- whether the application is lazy (`Just _`) and, if so, why
  -- (e.g. `Just LInf`, `Just LLazy`); `Nothing` means an ordinary application.
  public export
  data Lifted : List Name -> Type where
       -- A local variable, identified by its de Bruijn index `idx` into `vars`
       LLocal : {idx : Nat} -> FC -> (0 p : IsVar x idx vars) -> Lifted vars
       -- A known function applied to exactly the right number of arguments,
       -- so the runtime can Just Go
       LAppName : FC -> (lazy : Maybe LazyReason) -> Name -> List (Lifted vars) -> Lifted vars
       -- A known function applied to too few arguments, so the runtime should
       -- make a closure and wait for the remaining arguments
       LUnderApp : FC -> Name -> (missing : Nat) ->
                   (args : List (Lifted vars)) -> Lifted vars
       -- A closure applied to one more argument (so, for example a closure
       -- which is waiting for another argument before it can run).
       -- The runtime should add the argument to the closure and run the result
       -- if it is now fully applied.
       LApp : FC -> (lazy : Maybe LazyReason) -> (closure : Lifted vars) -> (arg : Lifted vars) -> Lifted vars
       -- `let x = val in scope`; in the scope, `x` is the innermost variable
       LLet : FC -> (x : Name) -> Lifted vars ->
              Lifted (x :: vars) -> Lifted vars
       -- A constructor applied to its arguments, with its tag if one is known
       LCon : FC -> Name -> (tag : Maybe Int) -> List (Lifted vars) -> Lifted vars
       -- A primitive operation applied to exactly `arity` arguments
       LOp : {arity : _} ->
             FC -> (lazy : Maybe LazyReason) -> PrimFn arity -> Vect arity (Lifted vars) -> Lifted vars
       -- An external primitive, referred to by name, applied to arguments
       LExtPrim : FC -> (lazy : Maybe LazyReason) -> (p : Name) -> List (Lifted vars) -> Lifted vars
       -- Case on a constructor: scrutinee, alternatives, optional default
       LConCase : FC -> Lifted vars ->
                  List (LiftedConAlt vars) ->
                  Maybe (Lifted vars) -> Lifted vars
       -- Case on a constant: scrutinee, alternatives, optional default
       LConstCase : FC -> Lifted vars ->
                    List (LiftedConstAlt vars) ->
                    Maybe (Lifted vars) -> Lifted vars
       -- A primitive constant value
       LPrimVal : FC -> Constant -> Lifted vars
       -- An erased term
       LErased : FC -> Lifted vars
       -- A crash carrying the given message
       LCrash : FC -> String -> Lifted vars

  -- A constructor-case alternative. The constructor's argument names are
  -- bound in front of the enclosing scope in the right-hand side.
  public export
  data LiftedConAlt : List Name -> Type where
       MkLConAlt : Name -> (tag : Maybe Int) -> (args : List Name) ->
                   Lifted (args ++ vars) -> LiftedConAlt vars

  -- A constant-case alternative: the constant to match and the right-hand side.
  public export
  data LiftedConstAlt : List Name -> Type where
       MkLConstAlt : Constant -> Lifted vars -> LiftedConstAlt vars

-- A top-level definition produced by lambda lifting.
public export
data LiftedDef : Type where
     -- A function.
     -- We take the outer scope and the function arguments separately so that
     -- we don't have to reshuffle de Bruijn indices, which is expensive.
     -- This should be compiled as a function which takes 'args' first,
     -- then 'reverse scope'.
     -- (Sorry for the awkward API - it's to do with how the indices are
     -- arranged for the variables, and it could be expensive to reshuffle them!
     -- See Compiler.ANF for an example of how they get resolved to names)
     MkLFun : (args : List Name) -> -- function arguments
              (scope : List Name) -> -- outer scope
              Lifted (scope ++ args) -> LiftedDef
     -- A data constructor: optional tag and arity. `nt` is `Just n` when the
     -- constructor is a newtype (rendered as "newtype by n").
     MkLCon : (tag : Maybe Int) -> (arity : Nat) -> (nt : Maybe Nat) -> LiftedDef
     -- A foreign function: its foreign specifier strings (`ccs`), argument
     -- types and return type.
     MkLForeign : (ccs : List String) ->
                  (fargs : List CFType) ->
                  CFType ->
                  LiftedDef
     -- A definition whose body is an error expression, in an empty scope.
     MkLError : Lifted [] -> LiftedDef

-- Render an optional laziness annotation: empty when there is none,
-- otherwise a space followed by the reason.
showLazy : Maybe LazyReason -> String
showLazy Nothing = ""
showLazy (Just reason) = " " ++ show reason

mutual
  -- Debug rendering of lifted terms. Locals print as `!name`, and a laziness
  -- annotation (if any) follows the applied function/operator.
  export
  {vs : _} -> Show (Lifted vs) where
    show (LLocal {idx} _ p) = "!" ++ show (nameAt p)
    show (LAppName fc lazy n args)
        = show n ++ showLazy lazy ++ "(" ++ showSep ", " (map show args) ++ ")"
    show (LUnderApp fc n m args)
        = "<" ++ show n ++ " underapp " ++ show m ++ ">(" ++
          showSep ", " (map show args) ++ ")"
    show (LApp fc lazy c arg)
        = show c ++ showLazy lazy ++ " @ (" ++ show arg ++ ")"
    show (LLet fc x val sc)
        = "%let " ++ show x ++ " = " ++ show val ++ " in " ++ show sc
    show (LCon fc n t args)
        = "%con " ++ show n ++ "(" ++ showSep ", " (map show args) ++ ")"
    show (LOp fc lazy op args)
        = "%op " ++ show op ++ showLazy lazy ++ "(" ++ showSep ", " (toList (map show args)) ++ ")"
    show (LExtPrim fc lazy p args)
        = "%extprim " ++ show p ++ showLazy lazy ++ "(" ++ showSep ", " (map show args) ++ ")"
    -- Case expressions print their alternatives followed by the optional
    -- default (shown as a `Maybe`), enclosed in a balanced `{ ... }`.
    show (LConCase fc sc alts def)
        = "%case " ++ show sc ++ " of { "
             ++ showSep "| " (map show alts) ++ " " ++ show def ++ " }"
    show (LConstCase fc sc alts def)
        = "%case " ++ show sc ++ " of { "
             ++ showSep "| " (map show alts) ++ " " ++ show def ++ " }"
    show (LPrimVal _ x) = show x
    show (LErased _) = "___"
    show (LCrash _ x) = "%CRASH(" ++ show x ++ ")"

  -- Debug rendering of a constructor-case alternative.
  export
  {vs : _} -> Show (LiftedConAlt vs) where
    show (MkLConAlt n t args sc)
        = "%conalt " ++ show n ++
             "(" ++ showSep ", " (map show args) ++ ") => " ++ show sc

  -- Debug rendering of a constant-case alternative.
  export
  {vs : _} -> Show (LiftedConstAlt vs) where
    show (MkLConstAlt c sc)
        = "%constalt(" ++ show c ++ ") => " ++ show sc

-- Debug rendering of a lifted top-level definition.
export
Show LiftedDef where
  show (MkLFun args scope body)
      = concat [show args, show (reverse scope), ": ", show body]
  show (MkLCon tag arity nt)
      = "Constructor tag " ++ show tag ++ " arity " ++ show arity ++
        (case nt of
              Nothing => ""
              Just n => " (newtype by " ++ show n ++ ")")
  show (MkLForeign ccs fargs ret)
      = concat ["Foreign call ", show ccs, " ", show fargs, " -> ", show ret]
  show (MkLError body) = "Error: " ++ show body


-- Label for the mutable `Ref` holding the lambda-lifting state (`LDefs`).
-- It has no constructors; it is only ever used as a type-level tag.
data Lifts : Type where

-- State threaded through the lambda lifting of one top-level definition.
record LDefs where
  constructor MkLDefs
  basename : Name -- top level name we're lifting from
  defs : List (Name, LiftedDef) -- new definitions we made (newest first)
  nextName : Int -- name of next definition to lift

-- Produce a fresh machine-generated name for the next lifted definition,
-- derived from the base name of the definition being lifted, and advance
-- the counter stored in the lifting state.
genName : {auto l : Ref Lifts LDefs} ->
          Core Name
genName
    = do st <- get Lifts
         put Lifts (record { nextName $= (+ 1) } st)
         pure (fresh (basename st) (nextName st))
  where
    -- Strip namespaces/display names down to something usable as the
    -- string part of a machine name.
    fresh : Name -> Int -> Name
    fresh (NS ns inner) k = NS ns (fresh inner k)
    fresh (UN str) k = MN str k
    fresh (DN _ real) k = fresh real k
    fresh (CaseBlock outer inner) k = MN ("case block in " ++ outer ++ " (" ++ show inner ++ ")") k
    fresh (WithBlock outer inner) k = MN ("with block in " ++ outer ++ " (" ++ show inner ++ ")") k
    fresh other k = MN (show other) k

-- Apply a lifted function to a list of arguments one at a time, building a
-- chain of closure applications (`LApp`).
-- Only the outermost `LApp` (the one consuming the last argument) carries the
-- laziness annotation. The intermediate applications only build closures, so
-- they stay unannotated. Putting the annotation on the innermost `f a` would
-- delay only that partial application and then apply the resulting thunk to
-- the remaining arguments.
unload : FC -> (lazy : Maybe LazyReason) -> Lifted vars -> List (Lifted vars) -> Core (Lifted vars)
unload fc _ f [] = pure f
unload fc lazy f [a] = pure $ LApp fc lazy f a
unload fc lazy f (a :: as) = unload fc lazy (LApp fc Nothing f a) as

-- Variable-usage information for a scope: one flag per variable of `vars`,
-- in the same order. `markUsed` sets a variable's flag to `True`.
record Used (vars : List Name) where
  constructor MkUsed
  used : Vect (length vars) Bool

-- Usage information for `vars` in which no variable has been used yet.
initUsed : {vars : _} -> Used vars
initUsed = MkUsed $ replicate (length vars) False

-- `length` distributes over list append; proved by induction on `xs`.
lengthDistributesOverAppend
  : (xs, ys : List a)
  -> length (xs ++ ys) = length xs + length ys
lengthDistributesOverAppend [] _ = Refl
lengthDistributesOverAppend (_ :: rest) ys
  = cong S (lengthDistributesOverAppend rest ys)

-- Extend usage information with the freshly bound variables `outer`, all
-- flagged as unused (`False`). The rewrite lets the vector of length
-- `length outer + length vars` built here stand in for one of length
-- `length (outer ++ vars)`.
weakenUsed : {outer : _} -> Used vars -> Used (outer ++ vars)
weakenUsed {outer} (MkUsed xs) =
  MkUsed (rewrite lengthDistributesOverAppend outer vars in
         (replicate (length outer) False ++ xs))

-- Forget the usage flag of the innermost bound variable (the head of the
-- vector), leaving the flags for the enclosing scope.
contractUsed : (Used (x::vars)) -> Used vars
contractUsed (MkUsed (_ :: flags)) = MkUsed flags

-- Forget the usage flags of every variable in `remove`, which are bound in
-- front of `vars`, one variable at a time.
contractUsedMany : {remove : _} ->
                   (Used (remove ++ vars)) ->
                   Used vars
contractUsedMany {remove = []} us = us
contractUsedMany {remove = _ :: rest} us
    = contractUsedMany {remove = rest} (contractUsed us)

-- Flag the variable at de Bruijn index `idx` as used. `finIdx` turns the
-- erased proof `prf` (that `idx` refers to a variable of `vars`) into a
-- bounded `Fin (length vars)` position, and that position is set to `True`.
markUsed : {vars : _} ->
           (idx : Nat) ->
           {0 prf : IsVar x idx vars} ->
           Used vars ->
           Used vars
markUsed {vars} {prf} idx (MkUsed us) =
  let newUsed = replaceAt (finIdx prf) True us in
  MkUsed newUsed
    where
    -- Convert a variable-membership proof into the matching `Fin`. Matching
    -- on the runtime `idx` selects which constructor the erased proof has.
    finIdx : {vars : _} -> {idx : _} ->
               (0 prf : IsVar x idx vars) ->
               Fin (length vars)
    finIdx {idx=Z} First = FZ
    finIdx {idx=S x} (Later l) = FS (finIdx l)

-- Invert usage information: in the result, `True` marks a variable that is
-- never used.
getUnused : Used vars ->
            Vect (length vars) Bool
getUnused (MkUsed flags) = not <$> flags

-- The names of `vars` that remain after removing every position flagged
-- `True` in `drop`; order is preserved. This function appears in the types
-- of `dropIdx` and `dropUnused`, so its clauses are also what the type
-- checker unfolds there.
total
dropped : (vars : List Name) ->
          (drop : Vect (length vars) Bool) ->
          List Name
dropped [] _ = []
dropped (x::xs) (False::us) = x::(dropped xs us)
dropped (x::xs) (True::us) = dropped xs us

mutual
  -- Lift a lambda into a new top-level definition. Directly nested `CLam`s
  -- are gathered into `bound` (most recently bound name first), so a
  -- multi-argument lambda becomes a single definition. The new definition
  -- abstracts only over those variables of `vars` that its lifted body
  -- actually uses. The result is an `LUnderApp` of the new name to exactly
  -- those variables, still missing `length bound` arguments.
  makeLam : {auto l : Ref Lifts LDefs} ->
            {vars : _} ->
            {doLazyAnnots : Bool} ->
            {default Nothing lazy : Maybe LazyReason} ->
            FC -> (bound : List Name) ->
            CExp (bound ++ vars) -> Core (Lifted vars)
  makeLam fc bound (CLam _ x sc') = makeLam fc {doLazyAnnots} {lazy} (x :: bound) sc'
  makeLam {vars} fc bound sc
      = do scl <- liftExp {doLazyAnnots} {lazy} sc
           -- Find out which variables aren't used in the new definition, and
           -- do not abstract over them in the new definition.
           let scUsedL = usedVars initUsed scl
               unusedContracted = contractUsedMany {remove=bound} scUsedL
               unused = getUnused unusedContracted
               scl' = dropUnused {outer=bound} unused scl
           n <- genName
           ldefs <- get Lifts
           -- Record the new definition: its `args` are the used subset of the
           -- captured `vars`, and its `scope` is the lambda's own `bound` names.
           put Lifts (record { defs $= ((n, MkLFun (dropped vars unused) bound scl') ::) } ldefs)
           pure $ LUnderApp fc n (length bound) (allVars fc vars unused)
    where
        -- A `Var` for each variable of `vs` that is not flagged in `unused`,
        -- in order, each weakened so it refers into the whole of `vs`.
        allPrfs : (vs : List Name) -> (unused : Vect (length vs) Bool) -> List (Var vs)
        allPrfs [] _ = []
        allPrfs (v :: vs) (False::uvs) = MkVar First :: map weaken (allPrfs vs uvs)
        allPrfs (v :: vs) (True::uvs) = map weaken (allPrfs vs uvs)

        -- apply to all the used variables. 'First' will be first in the list,
        -- which is good, because the most recently bound name is the first
        -- argument to the resulting function
        allVars : FC -> (vs : List Name) -> (unused : Vect (length vs) Bool) -> List (Lifted vs)
        allVars fc vs unused = map (\ (MkVar p) => LLocal fc p) (allPrfs vs unused)

-- Lift an expression, replacing every lambda with an (under-)application of a
-- freshly generated top-level definition (see makeLam).
-- If doLazyAnnots = True then annotate function application with laziness:
-- `CDelay` lifts its body with `lazy = Just reason` and `CForce` lifts its body
-- with `lazy = Nothing`. Otherwise use the old behaviour, where a thunk is a
-- function of one erased argument and forcing applies it to `CErased`.
-- Within a delayed term, `lazy` is passed on to the alternatives AND the
-- default of both kinds of case, so that every branch is annotated the same.
  liftExp : {vars : _} ->
            {auto l : Ref Lifts LDefs} ->
            {doLazyAnnots : Bool} ->
            {default Nothing lazy : Maybe LazyReason} ->
            CExp vars -> Core (Lifted vars)
  liftExp (CLocal fc prf) = pure $ LLocal fc prf
  liftExp (CRef fc n) = pure $ LAppName fc lazy n [] -- probably shouldn't happen!
  liftExp (CLam fc x sc) = makeLam {doLazyAnnots} {lazy} fc [x] sc
  liftExp (CLet fc x _ val sc) = pure $ LLet fc x !(liftExp {doLazyAnnots} val) !(liftExp {doLazyAnnots} sc)
  liftExp (CApp fc (CRef _ n) args) -- names are applied exactly in compileExp
      = pure $ LAppName fc lazy n !(traverse (liftExp {doLazyAnnots}) args)
  liftExp (CApp fc f args)
      = unload fc lazy !(liftExp {doLazyAnnots} f) !(traverse (liftExp {doLazyAnnots}) args)
  liftExp (CCon fc n t args) = pure $ LCon fc n t !(traverse (liftExp {doLazyAnnots}) args)
  liftExp (COp fc op args)
      = pure $ LOp fc lazy op !(traverseArgs args)
    where
      traverseArgs : Vect n (CExp vars) -> Core (Vect n (Lifted vars))
      traverseArgs [] = pure []
      traverseArgs (a :: as) = pure $ !(liftExp {doLazyAnnots} a) :: !(traverseArgs as)
  liftExp (CExtPrim fc p args) = pure $ LExtPrim fc lazy p !(traverse (liftExp {doLazyAnnots}) args)
  -- The pattern variables below are deliberately not called `lazy`, so they
  -- do not shadow the implicit `lazy` argument.
  liftExp (CForce fc _ tm) = if doLazyAnnots
    then liftExp {doLazyAnnots} {lazy = Nothing} tm
    else liftExp {doLazyAnnots} (CApp fc tm [CErased fc])
  liftExp (CDelay fc reason tm) = if doLazyAnnots
    then liftExp {doLazyAnnots} {lazy = Just reason} tm
    else liftExp {doLazyAnnots} (CLam fc (MN "act" 0) (weaken tm))
  liftExp (CConCase fc sc alts def)
      = pure $ LConCase fc !(liftExp {doLazyAnnots} sc) !(traverse (liftConAlt {lazy}) alts)
                           !(traverseOpt (liftExp {doLazyAnnots} {lazy}) def)
    where
      liftConAlt : {default Nothing lazy : Maybe LazyReason} ->
                   CConAlt vars -> Core (LiftedConAlt vars)
      liftConAlt (MkConAlt n t args sc) = pure $ MkLConAlt n t args !(liftExp {doLazyAnnots} {lazy} sc)
  liftExp (CConstCase fc sc alts def)
      = pure $ LConstCase fc !(liftExp {doLazyAnnots} sc) !(traverse (liftConstAlt {lazy}) alts)
                             !(traverseOpt (liftExp {doLazyAnnots} {lazy}) def)
    where
      liftConstAlt : {default Nothing lazy : Maybe LazyReason} ->
                     CConstAlt vars -> Core (LiftedConstAlt vars)
      liftConstAlt (MkConstAlt c sc) = pure $ MkLConstAlt c !(liftExp {doLazyAnnots} {lazy} sc)
  liftExp (CPrimVal fc c) = pure $ LPrimVal fc c
  liftExp (CErased fc) = pure $ LErased fc
  liftExp (CCrash fc str) = pure $ LCrash fc str

  -- Add to `used` every variable of `vars` that the term refers to.
  -- Under a binder (let scope, constructor-case alternative), the usage
  -- information is first extended with the new variables (weakenUsed) and
  -- then contracted again, so only the flags for `vars` survive.
  -- NOTE(review): the auto `Ref Lifts LDefs` is not used by this function.
  usedVars : {vars : _} ->
             {auto l : Ref Lifts LDefs} ->
             Used vars ->
             Lifted vars ->
             Used vars
  usedVars used (LLocal {idx} fc prf) =
    markUsed {prf} idx used
  usedVars used (LAppName fc lazy n args) =
    foldl (usedVars {vars}) used args
  usedVars used (LUnderApp fc n miss args) =
    foldl (usedVars {vars}) used args
  usedVars used (LApp fc lazy c arg) =
    usedVars (usedVars used arg) c
  usedVars used (LLet fc x val sc) =
    let innerUsed = contractUsed $ usedVars (weakenUsed {outer=[x]} used) sc in
        usedVars innerUsed val
  usedVars used (LCon fc n tag args) =
    foldl (usedVars {vars}) used args
  usedVars used (LOp fc lazy fn args) =
    foldl (usedVars {vars}) used args
  usedVars used (LExtPrim fc lazy fn args) =
    foldl (usedVars {vars}) used args
  usedVars used (LConCase fc sc alts def) =
      let defUsed = maybe used (usedVars used {vars}) def
          scDefUsed = usedVars defUsed sc in
          foldl usedConAlt scDefUsed alts
    where
      -- The alternative's pattern variables are bound in front of `vars`.
      usedConAlt : Used vars -> LiftedConAlt vars -> Used vars
      usedConAlt used (MkLConAlt n tag args sc) =
        contractUsedMany {remove=args} (usedVars (weakenUsed used) sc)

  usedVars used (LConstCase fc sc alts def) =
      let defUsed = maybe used (usedVars used {vars}) def
          scDefUsed = usedVars defUsed sc in
          foldl usedConstAlt scDefUsed alts
    where
      usedConstAlt : Used vars -> LiftedConstAlt vars -> Used vars
      usedConstAlt used (MkLConstAlt c sc) = usedVars used sc

  usedVars used (LPrimVal _ _) = used
  usedVars used (LErased _) = used
  usedVars used (LCrash _ _) = used

  -- Translate a reference into `outer ++ vars` into the matching reference
  -- into `outer ++ dropped vars unused`. Variables from `outer` keep their
  -- index; a variable from `vars` moves down past every dropped variable
  -- that precedes it. A reference to a variable flagged as unused is an
  -- internal error: the usage analysis and the term disagree.
  dropIdx : {vars : _} ->
            {idx : _} ->
            (outer : List Name) ->
            (unused : Vect (length vars) Bool) ->
            (0 p : IsVar x idx (outer ++ vars)) ->
            Var (outer ++ (dropped vars unused))
  dropIdx [] (False::_) First = MkVar First
  dropIdx [] (True::_) First = assert_total $
    idris_crash "INTERNAL ERROR: Referenced variable marked as unused"
  dropIdx [] (False::rest) (Later p) = Var.later $ dropIdx [] rest p
  -- A dropped variable earlier in the scope: skip it without adding `later`
  dropIdx [] (True::rest) (Later p) = dropIdx [] rest p
  dropIdx (_::xs) unused First = MkVar First
  dropIdx (_::xs) unused (Later p) = Var.later $ dropIdx xs unused p

  -- Re-index a lifted term from `outer ++ vars` to
  -- `outer ++ dropped vars unused`, removing the variables flagged in
  -- `unused`. Binders (let, constructor-case alternatives) extend `outer`,
  -- so the variables they bind are never dropped.
  -- NOTE(review): the explicit argument is named `l`, the same as the auto
  -- implicit `Ref Lifts LDefs`, which the body never uses.
  dropUnused : {vars : _} ->
               {auto l : Ref Lifts LDefs} ->
               {outer : List Name} ->
               (unused : Vect (length vars) Bool) ->
               (l : Lifted (outer ++ vars)) ->
               Lifted (outer ++ (dropped vars unused))
  dropUnused _ (LPrimVal fc val) = LPrimVal fc val
  dropUnused _ (LErased fc) = LErased fc
  dropUnused _ (LCrash fc msg) = LCrash fc msg
  dropUnused {outer} unused (LLocal fc p) =
    let (MkVar p') = dropIdx outer unused p in LLocal fc p'
  dropUnused unused (LCon fc n tag args) =
    let args' = map (dropUnused unused) args in
        LCon fc n tag args'
  dropUnused {outer} unused (LLet fc n val sc) =
    let val' = dropUnused unused val
        sc' = dropUnused {outer=n::outer} (unused) sc in
        LLet fc n val' sc'
  dropUnused unused (LApp fc lazy c arg) =
    let c' = dropUnused unused c
        arg' = dropUnused unused arg in
        LApp fc lazy c' arg'
  dropUnused unused (LOp fc lazy fn args) =
    let args' = map (dropUnused unused) args in
        LOp fc lazy fn args'
  dropUnused unused (LExtPrim fc lazy n args) =
    let args' = map (dropUnused unused) args in
        LExtPrim fc lazy n args'
  dropUnused unused (LAppName fc lazy n args) =
    let args' = map (dropUnused unused) args in
        LAppName fc lazy n args'
  dropUnused unused (LUnderApp fc n miss args) =
    let args' = map (dropUnused unused) args in
        LUnderApp fc n miss args'
  dropUnused {vars} {outer} unused (LConCase fc sc alts def) =
    let alts' = map dropConCase alts in
        LConCase fc (dropUnused unused sc) alts' (map (dropUnused unused) def)
    where
      -- The alternative's scope is `args ++ (outer ++ vars)`; reassociate it
      -- to `(args ++ outer) ++ vars` so `args` can join `outer`, then
      -- reassociate the result back.
      dropConCase : LiftedConAlt (outer ++ vars) ->
                    LiftedConAlt (outer ++ (dropped vars unused))
      dropConCase (MkLConAlt n t args sc) =
        let sc' = (rewrite sym $ appendAssociative args outer vars in sc)
            droppedSc = dropUnused {vars=vars} {outer=args++outer} unused sc' in
        MkLConAlt n t args (rewrite appendAssociative args outer (dropped vars unused) in droppedSc)
  dropUnused {vars} {outer} unused (LConstCase fc sc alts def) =
    let alts' = map dropConstCase alts in
        LConstCase fc (dropUnused unused sc) alts' (map (dropUnused unused) def)
    where
      -- Constant alternatives bind no variables, so no reassociation is needed.
      dropConstCase : LiftedConstAlt (outer ++ vars) ->
                      LiftedConstAlt (outer ++ (dropped vars unused))
      dropConstCase (MkLConstAlt c val) = MkLConstAlt c (dropUnused unused val)

-- Lambda-lift one expression belonging to the definition named `n`.
-- Returns the lifted expression together with every auxiliary definition
-- created for the lambdas it contained (newest first).
export
liftBody : {vars : _} -> {doLazyAnnots : Bool} ->
           Name -> CExp vars -> Core (Lifted vars, List (Name, LiftedDef))
liftBody n tm
    = do st <- newRef Lifts (MkLDefs n [] 0)
         body <- liftExp {doLazyAnnots} {l = st} tm
         final <- get Lifts
         pure (body, defs final)

-- Lambda-lift one compiled definition. For functions and errors the body is
-- lifted, and the definition itself is placed first in the result, followed
-- by the definitions lifted out of it. Constructors and foreign declarations
-- are passed through unchanged.
lambdaLiftDef : (doLazyAnnots : Bool) -> Name -> CDef -> Core (List (Name, LiftedDef))
lambdaLiftDef annots n (MkFun args exp)
    = do (body, extra) <- liftBody {doLazyAnnots = annots} n exp
         pure $ (n, MkLFun args [] body) :: extra
lambdaLiftDef annots n (MkError exp)
    = do (body, extra) <- liftBody {doLazyAnnots = annots} n exp
         pure $ (n, MkLError body) :: extra
lambdaLiftDef _ n (MkCon t a nt) = pure [(n, MkLCon t a nt)]
lambdaLiftDef _ n (MkForeign ccs fargs ty) = pure [(n, MkLForeign ccs fargs ty)]

-- Return the lambda lifted definitions required for the given name.
-- If the name hasn't been compiled yet (via CompileExpr.compileDef) then
-- this will return an empty list.
-- An empty list is therefore an error, because on success you will always
-- get at least one definition: the lifted definition for the given name.
export
lambdaLift : {auto c : Ref Ctxt Defs} ->
             (doLazyAnnots : Bool) ->
             Name -> Core (List (Name, LiftedDef))
lambdaLift doLazyAnnots n
    = do ctxt <- get Ctxt
         Just gdef <- lookupCtxtExact n (gamma ctxt)
             | Nothing => pure []
         case compexpr gdef of
              Nothing => pure []
              Just cexpr => lambdaLiftDef doLazyAnnots n cexpr