~cypheon/rapid

722af631d9cbc6265eaba92a3f085f86f7847541 — Johann Rudloff 11 months ago ed831e3 experiment/local-continuations1
First attempt at optimising local continuations (does not work)
M src/Compiler/CPS.idr => src/Compiler/CPS.idr +18 -8
@@ 22,22 22,27 @@ data ContName : Type where
  BigK : Int -> ContName
  -- local continuation
  MkContName : String -> Int -> ContName
  -- closure as continuation
  Clos : Name -> ContName

export
Show ContName where
  show (BigK i) = "[K\{show i}]"
  show (MkContName s i) = "K:" ++ s ++ ":" ++ (show i)
  show (Clos n) = "K:<\{show n}>"

Eq ContName where
  (BigK i) == (BigK j) = (i == j)
  (MkContName s1 i1) == (MkContName s2 i2) =
    (i1 == i2) && (s1 == s2)
  (Clos n1) == (Clos n2) = n1 == n2
  _ == _ = False

mutual
  public export
  data CPSNmExp : Type where
    KNAppCont : FC -> (k : ContName) -> (arg : Atom) -> CPSNmExp
    KNJump : FC -> (f : Atom) -> (kn : ContName) -> (args : List Atom) -> CPSNmExp
    KNJump : FC -> (f : Atom) -> (k : ContName) -> (args : List Atom) -> CPSNmExp
    KNCon : FC -> Name -> ConInfo -> (tag : Maybe Int) -> List Atom -> Name -> CPSNmExp -> CPSNmExp
    KNExtPrim : FC -> Name -> List Atom -> Name -> CPSNmExp -> CPSNmExp
    KNFix : FC -> (k : ContName) -> (args : List Name) -> (body : CPSNmExp) -> Name -> CPSNmExp -> CPSNmExp


@@ 168,9 173,12 @@ cps nm (NmApp fc f args) kf = do
  resultVar <- getUnique' "res_"
  cont <- kf (KLocal emptyFC resultVar)
  contName <- makeContName "app"
  bk <- makeBigK
  contName' <- getUnique' "app"
  app <- cps nm f (\x =>
    argHelper nm args (\cargs => pure $ KNJump fc x contName cargs))
  pure (KNLetCont fc resultVar cont contName app)
    argHelper nm args (\cargs => pure $ KNJump fc x (Clos contName') cargs))
  --pure (KNLetCont fc resultVar cont contName app)
  pure (KNFix fc bk [resultVar] cont contName' app)

cps nm (NmCon fc n ci tag args) kf = do
  resultVar <- getUnique' "con_"


@@ 193,7 201,7 @@ cps nm (NmConstCase fc sc alts def) kf = do
  lambdaArg <- getUnique' "cca_"
  cbody <- kf (KLocal emptyFC lambdaArg)

  let contSmall = isSmallerThan 2 cbody
  let contSmall = False --isSmallerThan 2 cbody
  let usekf = if contSmall then kf else newkf

  cps nm sc (\csc => do


@@ 208,7 216,7 @@ cps nm (NmConCase fc sc alts def) kf = do
  lambdaArg <- getUnique' "cca_"
  cbody <- kf (KLocal emptyFC lambdaArg)

  let contSmall = isSmallerThan 5 cbody
  let contSmall = False --isSmallerThan 5 cbody
  let usekf = if contSmall then kf else newkf

  cps nm sc (\csc => do


@@ 224,11 232,13 @@ cps nm (NmExtPrim fc p args) kf = do
    pure $ KNExtPrim fc p cargs resultVar cont)
-- TODO: test:
cps nm (NmForce fc reason e) kf = do
  contName <- makeContName "tk"
  contName <- getUnique' "tk"
  bk <- makeBigK
  resultVar <- getUnique' "fres_"
  cont <- kf (KLocal emptyFC resultVar)
  cps nm e (\thunk => do
    pure $ KNLetCont fc resultVar cont contName $ KNJump fc thunk contName [])
    --pure $ KNLetCont fc resultVar cont contName $ KNJump fc thunk contName [])
    pure $ KNFix fc bk [resultVar] cont contName $ KNJump fc thunk (Clos contName) [])
-- TODO: test:
cps nm (NmDelay fc reason e) kf = do
  thunkCont <- makeBigK


@@ 341,7 351,7 @@ khalt arg = pure $ KNAppCont emptyFC (BigK 0) arg

toCPS : NamedCExp -> Core CPSNmExp
toCPS e = do
  cpsIdx <- newRef CPSIndex 0
  cpsIdx <- newRef CPSIndex 1
  cps empty e khalt

simplifyTwice : CPSNmExp -> CPSNmExp

M src/Compiler/CPSLift.idr => src/Compiler/CPSLift.idr +25 -5
@@ 19,6 19,8 @@ mutual
    KLJump : FC -> (f : Name) -> (args : List Atom) -> CPSLifted
    KLApp : FC -> (f : Atom) -> (args : List Atom) -> CPSLifted
    KLClosure : FC -> Name -> (missing : Nat) -> (free : List Atom) -> Name -> CPSLifted -> CPSLifted
    KLLocalCont : FC -> (arg : Name) -> (body : CPSLifted) -> ContName -> CPSLifted -> CPSLifted
    KLAppCont : FC -> ContName -> (arg : Atom) -> CPSLifted
    KLCon : FC -> Name -> ConInfo -> (tag : Maybe Int) -> List Atom -> Name -> CPSLifted -> CPSLifted
    KLExtPrim : FC -> Name -> List Atom -> Name -> CPSLifted -> CPSLifted
    KLOp : {arity : _} -> FC -> PrimFn arity -> Vect arity Atom -> Name -> CPSLifted -> CPSLifted


@@ 49,6 51,8 @@ mutual
  Show CPSLifted where
    show (KLJump _ n args) = assert_total "(KLJump \{show n} [\{show args}])"
    show (KLApp _ f args) = assert_total "(KLApp \{show f} [\{show args}])"
    show (KLLocalCont _ arg body kn k) = assert_total "(KLLocalCont [\{show arg}] => \{show body}, bind \{show kn} in \{show k})"
    show (KLAppCont _ kn arg) = assert_total "(KLAppCont \{show kn} [\{show arg}])"
    show (KLClosure _ fn miss args n k) = assert_total "(Closure \{show fn} (\{show args}), bind \{show n} in \{show k} )"
    show (KLCon _ name ci tag args n k) = assert_total "[%con \{show ci} \{show name} \{show tag} \{show args}, bind \{show n} in \{show k}]"
    show (KLExtPrim _ prim args n k) = assert_total "[%extern \{show prim} \{show args}, bind \{show n} in \{show k}]"


@@ 73,6 77,8 @@ mutual
  ishow : CPSLifted -> String
  ishow (KLJump _ n args) = assert_total "(KLJump \{show n} [\{show args}])"
  ishow (KLApp _ f args) = assert_total "(KLApp \{show f} [\{show args}])"
  ishow (KLLocalCont _ arg body kn k) = assert_total "(KLLocalCont (\{show arg}) => (\{ishow body}), bind \{show kn} in ... )"
  ishow (KLAppCont _ kn arg) = assert_total "(KLAppCont \{show kn} [\{show arg}])"
  ishow (KLClosure _ fn miss args n k) = assert_total "(Closure \{show fn} (\{show args}), bind \{show n} in ... )"
  ishow (KLCon _ name ci tag args n k) = assert_total "[%con \{show ci} \{show name} \{show tag} \{show args}, bind \{show n} in ... ]"
  ishow (KLExtPrim _ prim args n k) = assert_total "[%extern \{show prim} \{show args}, bind \{show n} in ... ]"


@@ 85,6 91,12 @@ mutual
SortedSet : Type -> Type
SortedSet k = SortedMap k ()

export
decontName : ContName -> Name
decontName (BigK i) = MN ("__cps_BIGK") i
decontName (MkContName s i) = MN ("__cps_" ++ s) i
decontName (Clos n) = n

findFreeAtom : Atom -> SortedSet Name
findFreeAtom (KLocal _ n) = singleton n ()
findFreeAtom _ = empty


@@ 98,6 110,15 @@ findFree (KLJump _ f args) =
  foldMap findFreeAtom args
findFree (KLApp _ f args) =
  findFreeAtom f <+> foldMap findFreeAtom args
findFree (KLAppCont _ kn@(MkContName _ _) arg) =
  -- ignore local continuation
  findFreeAtom arg
findFree (KLAppCont _ kn arg) =
  insert (decontName kn) () $ findFreeAtom arg
findFree (KLLocalCont _ arg body _ k) =
  let freeK = findFree k in
  let freeBody = delete arg (findFree body) in
  freeK <+> freeBody
findFree (KLClosure _ _ _ args n k) =
  let freeK = delete n (findFree k) in
  freeK <+> foldMap findFreeAtom args


@@ 145,10 166,6 @@ emitFunc fc free args body = do
  put Lifted ((n+1), f::funcs)
  pure newName

decontName : ContName -> Name
decontName (BigK i) = MN ("__cps_BIGK") i
decontName (MkContName s i) = MN ("__cps_" ++ s) i

liftCPS : {auto lif : Ref Lifted LiftedState} ->
          CPSNmExp ->
          Core CPSLifted


@@ 163,9 180,12 @@ liftCPS (KNLetCont fc arg body kn k) = do
  lbody <- liftCPS body
  lk <- liftCPS k
  let rawKn = decontName kn
  {-
  let freeVars = keys $ delete rawKn $ delete arg (findFree lbody)
  liftedFuncName <- emitFunc fc freeVars [arg] lbody
  pure (KLClosure fc liftedFuncName 1 (map (KLocal emptyFC) freeVars) rawKn lk)
  -}
  pure (KLLocalCont fc arg lbody kn lk)
liftCPS (KNCon fc name ci tag args n k) = do
  lk <- liftCPS k
  pure (KLCon fc name ci tag args n lk)


@@ 202,7 222,7 @@ liftCPS (KNJump fc (KLocal fc' n) kn args) = do
liftCPS (KNJump fc x _ _) = do
  pure (KLCrash fc "cannot call \{show x}")
liftCPS (KNAppCont fc kn arg) = do
  pure (KLApp fc (KLocal emptyFC (decontName kn)) [arg])
  pure (KLAppCont fc kn arg)
liftCPS (KNLetConst fc c n k) = do
  lk <- liftCPS k
  pure (KLLetConst fc c n lk)

M src/Compiler/LLVM/Rapid/CPSCodegen.idr => src/Compiler/LLVM/Rapid/CPSCodegen.idr +38 -0
@@ 692,6 692,7 @@ getInstIR instr@(KLClosure fc fn missing free n k) = do
  appendCode "; \{ishow instr}"
  mkClosure fn n (cast missing) free
  getInstIR k

getInstIR instr@(KLJump fc fn args) =
  if length args <= MAX_ARGS_DIRECT
  then do


@@ 709,6 710,40 @@ getInstIR instr@(KLJump fc fn args) =
    let iicName = MN "__iic" !(getUnique)
    appendCode "%\{safeName dummyArgName} = bitcast i8* null to %ObjPtr\n"
    getInstIR (KLClosure fc fn 1 args iicName (KLApp fc (KLocal fc iicName) [KLocal fc dummyArgName]))
getInstIR instr@(KLLocalCont fc arg body kn k) = do
  let lbl = "k" ++ (safeName $ decontName kn)
  appendCode "; prepare arg for \{lbl}:"
  appendCode "%\{lbl}.arg0 = alloca ptr"
  getInstIR k
  appendCode "; \{ishow instr}"
  appendCode "\{lbl}:"
  appendCode "; load arg for \{lbl}:"
  appendCode "%\{safeName arg} = load ptr, ptr %\{lbl}.arg0"
  getInstIR body
getInstIR instr@(KLAppCont fc (BigK n) arg) = do
  appendCode "; \{ishow instr}"
  heapOffset <- getHeapOffset
  when (heapOffset > 0) (addError "more heap allocated than checked")
  let heapCheck = SSA I64 "%HeapCheck"
  heapNext <- mkSub heapCheck (pConst (-heapOffset))
  appendCode $ "musttail call tailcc void @__apply_closure1(" ++ (toIR heapNext) ++ ", %TSOPtr %BaseArg, ptr %\{safeName $ decontName (BigK n)}, " ++ (toIR $ !(atom2val arg)) ++ ")"
  appendCode "ret void"
getInstIR instr@(KLAppCont fc (Clos n) arg) = do
  appendCode "; \{ishow instr}"
  heapOffset <- getHeapOffset
  when (heapOffset > 0) (addError "more heap allocated than checked")
  let heapCheck = SSA I64 "%HeapCheck"
  heapNext <- mkSub heapCheck (pConst (-heapOffset))
  appendCode $ "musttail call tailcc void @__apply_closure1(" ++ (toIR heapNext) ++ ", %TSOPtr %BaseArg, ptr %\{safeName n}, " ++ (toIR $ !(atom2val arg)) ++ ")"
  appendCode "ret void"
getInstIR instr@(KLAppCont fc kn arg) = do
  appendCode "; \{ishow instr}"
  heapOffset <- getHeapOffset
  when (heapOffset > 0) (addError "more heap allocated than checked")
  let heapCheck = SSA I64 "%HeapCheck"
  heapNext <- mkSub heapCheck (pConst (-heapOffset))
  appendCode $ "store ptr %k\{safeName $ decontName kn}.arg0, " ++ toIR !(atom2val arg)
  appendCode $ "br label %k\{safeName $ decontName kn}"
getInstIR instr@(KLApp fc f [arg1]) = do
  appendCode "; \{ishow instr}"
  heapOffset <- getHeapOffset


@@ 1002,6 1037,9 @@ maxRequiredHeapSpace (KLJump _ _ args) = let argCount = length args in
  if argCount <= MAX_ARGS_DIRECT then 0
    else 16 + 8 * (cast {to=Int} argCount)
maxRequiredHeapSpace (KLApp _ _ _) = 0
maxRequiredHeapSpace (KLAppCont _ _ _) = 0
maxRequiredHeapSpace (KLLocalCont _ _ body _ k) =
  maxRequiredHeapSpace body + maxRequiredHeapSpace k
maxRequiredHeapSpace (KLClosure _ _ _ args _ k) = 16 + 8 * (cast {to=Int} $ length args) + maxRequiredHeapSpace k
maxRequiredHeapSpace (KLCon _ _ _ _ args _ k) = 8 + 8 * (cast {to=Int} $ length args) + maxRequiredHeapSpace k
maxRequiredHeapSpace (KLExtPrim _ (NS _ primName) _ _ k) =

M src/Compiler/PrepareCode.idr => src/Compiler/PrepareCode.idr +2 -0
@@ 16,7 16,9 @@ constructorNamesConstAlt : CLConstAlt -> List Name
collectConstructorNames : CPSLifted -> List Name
collectConstructorNames (KLJump _ _ _) = []
collectConstructorNames (KLApp _ _ _) = []
collectConstructorNames (KLAppCont _ _ _) = []
collectConstructorNames (KLClosure _ _ _ _ _ k) = collectConstructorNames k
collectConstructorNames (KLLocalCont _ _ body _ k) = collectConstructorNames body ++ collectConstructorNames k
collectConstructorNames (KLCon _ name _ (Just tag) _ _ k) = collectConstructorNames k
collectConstructorNames (KLCon _ name _ Nothing _ _ k) = name :: collectConstructorNames k
collectConstructorNames (KLExtPrim _ _ _ _ k) = collectConstructorNames k