~jojo/Carth

bff74e256276bc58dea5792b8da5a7331b5a1b72 — JoJo a month ago 0dccc6a
lower: Finish impl of tail recursion to loop conversion in Lower

Yeah, I think this was all of it actually? Sygytt!
2 files changed, 73 insertions(+), 23 deletions(-)

M src/Back/Low.hs
M src/Back/Lower.hs
M src/Back/Low.hs => src/Back/Low.hs +1 -1
@@ 77,7 77,7 @@ data Statement'
type Statement = Branch Statement'

data Terminator'
    = TRetVal Operand
    = TRetVal Expr
    | TRetVoid
    deriving Show


M src/Back/Lower.hs => src/Back/Lower.hs +72 -22
@@ 133,29 133,58 @@ lower (Program (Topo defs) datas externs) =
        capturesName <- newLName "captures"
        let rt' = lowerType rt
        let paramTs = map (\t -> if passByRef t then Low.TPtr t else t) directParamTs
        (outerParamIds, body'') <- withVars binds $ case rt' of
            ZeroSized -> do
                body' <- lowerExpr VoidDest body
                let isTailRec = isBranchTailRec (last (Low.blockStms body')) $ \case
                        Low.VoidCall (Low.OGlobal other) _ | other == self -> True
                        _ -> False
                if isTailRec
        let innerParams = zipWith Low.Local innerParamIds paramTs
        -- Lower the body, generate an out-parameter if the return value is to be passed
        -- on the stack, and optimize to loop if the function is tail recursive.
        (outParam, outerParamIds, body'') <- do
            -- These will be discarded if the function is not tail recursive. In that
            -- case, the inner params and the outer params are the same.
            outerParamIds <- mapM spinoffLocalId innerParamIds
            let outerParams = zipWith Low.Local outerParamIds paramTs
            withVars binds $ case rt' of
                ZeroSized -> do
                    body' <- lowerExpr VoidDest body
                    if isTailRec_RetVoid self body'
                        then fmap
                            (Nothing, outerParamIds, )
                            (tailCallOpt_RetVoid self outerParams innerParams body')
                        else pure
                            ( Nothing
                            , innerParamIds
                            , mapTerm (\() -> Low.BLeaf Low.TRetVoid) body'
                            )
                Sized t -> if passByRef t
                    then do
                        outerParamIds <- mapM spinoffLocalId innerParamIds
                        let innerParams = zipWith Low.Local innerParamIds paramTs
                        let outerParams = zipWith Low.Local outerParamIds paramTs
                        fmap (outerParamIds, )
                             (tailCallOptZeroSized self outerParams innerParams body')
                    else pure
                        (innerParamIds, mapTerm (\() -> Low.BLeaf Low.TRetVoid) body')
            Sized t -> if passByRef t
                then undefined (lowerExpr (Address undefined) body)
                else undefined (lowerExpr Register body)
                        outParamId <- newLName "sret"
                        let outParamOp = Low.OLocal $ Low.Local outParamId (Low.TPtr t)
                        let outParam = Just $ Low.ByRef outParamId t
                        body' <- lowerExpr (Address outParamOp) body
                        if isTailRec_RetVoid self body'
                            then fmap
                                (outParam, outerParamIds, )
                                (tailCallOpt_RetVoid self outerParams innerParams body')
                            else pure
                                ( outParam
                                , innerParamIds
                                , mapTerm (\() -> Low.BLeaf Low.TRetVoid) body'
                                )
                    else do
                        body' <- lowerExpr Register body
                        if isTailRec_RetVal self body'
                            then fmap
                                (Nothing, outerParamIds, )
                                (tailCallOpt_RetVal self outerParams innerParams body')
                            else pure
                                ( Nothing
                                , innerParamIds
                                , mapTerm (Low.BLeaf . Low.TRetVal) body'
                                )
        localNames <- popLocalNames
        allocs <- popAllocs
        pure $ Low.FunDef
            name
            (Low.ByVal capturesName Low.VoidPtr
            (maybe id (:) outParam
            $ Low.ByVal capturesName Low.VoidPtr
            : zipWith sizedToParam outerParamIds directParamTs
            )
            (undefined rt')


@@ 163,8 192,16 @@ lower (Program (Topo defs) datas externs) =
            allocs
            localNames

    tailCallOptZeroSized self outerParams innerParams body = do
        let (bodyStms, lastStm) = fromJust (unsnoc (Low.blockStms body))
    isTailRec_RetVoid self (Low.Block stms ()) = isBranchTailRec (last stms) $ \case
        Low.VoidCall (Low.OGlobal other) _ | other == self -> True
        _ -> False

    isTailRec_RetVal self (Low.Block _ (Low.Expr e _)) = isBranchTailRec e $ \case
        Low.Call (Low.OGlobal other) _ | other == self -> True
        _ -> False

    tailCallOpt_RetVoid self outerParams innerParams (Low.Block stms ()) = do
        let (bodyStms, lastStm) = fromJust (unsnoc stms)
        let loopTermBlock = tailCallOptBranch lastStm $ \case
                Low.VoidCall (Low.OGlobal other) args | other == self ->
                    Low.Block [] (Low.Continue args)


@@ 174,6 211,19 @@ lower (Program (Topo defs) datas externs) =
        let loop = Low.Loop loopParams loopInner
        pure $ Low.Block [Low.BLeaf (Low.SLoop loop)] (Low.BLeaf Low.TRetVoid)

    tailCallOpt_RetVal self outerParams innerParams body = do
        let (Low.Block bodyStms (Low.Expr lastExpr t)) = body
        let loopTermBlock = tailCallOptBranch lastExpr $ \case
                Low.Call (Low.OGlobal other) args | other == self ->
                    Low.Block [] (Low.Continue args)
                e -> Low.Block [] (Low.Break (Low.Expr (Low.BLeaf e) t))
        let loopInner = Low.Block bodyStms () `thenBlock` loopTermBlock
        let loopParams = zip innerParams (map Low.OLocal outerParams)
        let loop = Low.Loop loopParams loopInner
        pure $ Low.Block
            []
            (Low.BLeaf (Low.TRetVal (Low.Expr (Low.BLeaf (Low.ELoop loop)) t)))

    isBranchTailRec br f = case br of
        Low.BLeaf x -> f x
        Low.BIf _ (Low.Block _ y1) (Low.Block _ y2) ->


@@ 182,8 232,8 @@ lower (Program (Topo defs) datas externs) =
            any (flip isBranchTailRec f . Low.blockTerm . snd) cs || isBranchTailRec d f

    tailCallOptBranch
        :: Low.Statement
        -> (Low.Statement' -> Low.Block (Low.LoopTerminator' t))
        :: Low.Branch a
        -> (a -> Low.Block (Low.LoopTerminator' t))
        -> Low.Block (Low.LoopTerminator t)
    tailCallOptBranch br f = case br of
        Low.BLeaf x -> mapTerm Low.BLeaf (f x)