@@ 77,7 77,7 @@ data Statement'
type Statement = Branch Statement'
data Terminator'
- = TRetVal Operand
+ = TRetVal Expr
| TRetVoid
deriving Show
@@ 133,29 133,58 @@ lower (Program (Topo defs) datas externs) =
capturesName <- newLName "captures"
let rt' = lowerType rt
let paramTs = map (\t -> if passByRef t then Low.TPtr t else t) directParamTs
- (outerParamIds, body'') <- withVars binds $ case rt' of
- ZeroSized -> do
- body' <- lowerExpr VoidDest body
- let isTailRec = isBranchTailRec (last (Low.blockStms body')) $ \case
- Low.VoidCall (Low.OGlobal other) _ | other == self -> True
- _ -> False
- if isTailRec
+ let innerParams = zipWith Low.Local innerParamIds paramTs
+ -- Lower the body, generate an out-parameter if the return value is to be passed
+ -- on the stack, and optimize to loop if the function is tail recursive.
+ (outParam, outerParamIds, body'') <- do
+ -- These will be discarded if the function is not tail recursive. In that
+ -- case, the inner params and the outer params are the same.
+ outerParamIds <- mapM spinoffLocalId innerParamIds
+ let outerParams = zipWith Low.Local outerParamIds paramTs
+ withVars binds $ case rt' of
+ ZeroSized -> do
+ body' <- lowerExpr VoidDest body
+ if isTailRec_RetVoid self body'
+ then fmap
+ (Nothing, outerParamIds, )
+ (tailCallOpt_RetVoid self outerParams innerParams body')
+ else pure
+ ( Nothing
+ , innerParamIds
+ , mapTerm (\() -> Low.BLeaf Low.TRetVoid) body'
+ )
+ Sized t -> if passByRef t
then do
- outerParamIds <- mapM spinoffLocalId innerParamIds
- let innerParams = zipWith Low.Local innerParamIds paramTs
- let outerParams = zipWith Low.Local outerParamIds paramTs
- fmap (outerParamIds, )
- (tailCallOptZeroSized self outerParams innerParams body')
- else pure
- (innerParamIds, mapTerm (\() -> Low.BLeaf Low.TRetVoid) body')
- Sized t -> if passByRef t
- then undefined (lowerExpr (Address undefined) body)
- else undefined (lowerExpr Register body)
+ outParamId <- newLName "sret"
+ let outParamOp = Low.OLocal $ Low.Local outParamId (Low.TPtr t)
+ let outParam = Just $ Low.ByRef outParamId t
+ body' <- lowerExpr (Address outParamOp) body
+ if isTailRec_RetVoid self body'
+ then fmap
+ (outParam, outerParamIds, )
+ (tailCallOpt_RetVoid self outerParams innerParams body')
+ else pure
+ ( outParam
+ , innerParamIds
+ , mapTerm (\() -> Low.BLeaf Low.TRetVoid) body'
+ )
+ else do
+ body' <- lowerExpr Register body
+ if isTailRec_RetVal self body'
+ then fmap
+ (Nothing, outerParamIds, )
+ (tailCallOpt_RetVal self outerParams innerParams body')
+ else pure
+ ( Nothing
+ , innerParamIds
+ , mapTerm (Low.BLeaf . Low.TRetVal) body'
+ )
localNames <- popLocalNames
allocs <- popAllocs
pure $ Low.FunDef
name
- (Low.ByVal capturesName Low.VoidPtr
+ (maybe id (:) outParam
+ $ Low.ByVal capturesName Low.VoidPtr
: zipWith sizedToParam outerParamIds directParamTs
)
(undefined rt')
@@ 163,8 192,16 @@ lower (Program (Topo defs) datas externs) =
allocs
localNames
- tailCallOptZeroSized self outerParams innerParams body = do
- let (bodyStms, lastStm) = fromJust (unsnoc (Low.blockStms body))
+ isTailRec_RetVoid self (Low.Block stms ()) = isBranchTailRec (last stms) $ \case
+ Low.VoidCall (Low.OGlobal other) _ | other == self -> True
+ _ -> False
+
+ isTailRec_RetVal self (Low.Block _ (Low.Expr e _)) = isBranchTailRec e $ \case
+ Low.Call (Low.OGlobal other) _ | other == self -> True
+ _ -> False
+
+ tailCallOpt_RetVoid self outerParams innerParams (Low.Block stms ()) = do
+ let (bodyStms, lastStm) = fromJust (unsnoc stms)
let loopTermBlock = tailCallOptBranch lastStm $ \case
Low.VoidCall (Low.OGlobal other) args | other == self ->
Low.Block [] (Low.Continue args)
@@ 174,6 211,19 @@ lower (Program (Topo defs) datas externs) =
let loop = Low.Loop loopParams loopInner
pure $ Low.Block [Low.BLeaf (Low.SLoop loop)] (Low.BLeaf Low.TRetVoid)
+ tailCallOpt_RetVal self outerParams innerParams body = do
+ let (Low.Block bodyStms (Low.Expr lastExpr t)) = body
+ let loopTermBlock = tailCallOptBranch lastExpr $ \case
+ Low.Call (Low.OGlobal other) args | other == self ->
+ Low.Block [] (Low.Continue args)
+ e -> Low.Block [] (Low.Break (Low.Expr (Low.BLeaf e) t))
+ let loopInner = Low.Block bodyStms () `thenBlock` loopTermBlock
+ let loopParams = zip innerParams (map Low.OLocal outerParams)
+ let loop = Low.Loop loopParams loopInner
+ pure $ Low.Block
+ []
+ (Low.BLeaf (Low.TRetVal (Low.Expr (Low.BLeaf (Low.ELoop loop)) t)))
+
isBranchTailRec br f = case br of
Low.BLeaf x -> f x
Low.BIf _ (Low.Block _ y1) (Low.Block _ y2) ->
@@ 182,8 232,8 @@ lower (Program (Topo defs) datas externs) =
any (flip isBranchTailRec f . Low.blockTerm . snd) cs || isBranchTailRec d f
tailCallOptBranch
- :: Low.Statement
- -> (Low.Statement' -> Low.Block (Low.LoopTerminator' t))
+ :: Low.Branch a
+ -> (a -> Low.Block (Low.LoopTerminator' t))
-> Low.Block (Low.LoopTerminator t)
tailCallOptBranch br f = case br of
Low.BLeaf x -> mapTerm Low.BLeaf (f x)