~jojo/Carth

4787862839d89f42711d9e708a5af3e5f5f7311a — JoJo 3 months ago 5f5c7f7
Simplify Codegen by using classes for the tail-call stuff
1 files changed, 89 insertions(+), 157 deletions(-)

M src/Codegen.hs
M src/Codegen.hs => src/Codegen.hs +89 -157
@@ 1,4 1,4 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE DuplicateRecordFields, GADTs #-}

-- | Generation of LLVM IR code from our monomorphic AST.
module Codegen (codegen) where


@@ 190,7 190,7 @@ genMain = do
        emitDo (callIntern Nothing init_ [(null' typeGenericPtr, []), (litUnit, [])])
        iof <- lookupVar (TypedVar "main" mainType)
        f <- genIndexStruct iof [0]
        _ <- app (Just NoTail) f (VLocal litRealWorld)
        _ <- app' @Val f (VLocal litRealWorld)
        commitFinalFuncBlock (ret (litI32 0))
    pure (GlobalDefinition (externFunc (mkName "main") [] i32 basicBlocks []))



@@ 257,34 257,25 @@ genGlobFunDef (TypedVar v _, WithPos dpos (ts, (p, (body, rt)))) = do
    pure (GlobalDefinition closureDef : GlobalDefinition f : gs)

genTailExpr :: Expr -> Gen ()
genTailExpr (Expr pos expr) = locally srcPos (pos <|>) $ do
    parent <- use lambdaParentFunc <* assign lambdaParentFunc Nothing
    case expr of
        App f e _ -> genTailApp f e
        If p c a -> genTailIf p c a
        Let d b -> genTailLet d b
        Match e cs tbody -> genTailMatch e cs =<< genType tbody
        _ -> genTailReturn =<< case expr of
            Fun (p, b) -> assign lambdaParentFunc parent *> genExprLambda p b
            _ -> genExpr (Expr pos expr)

genTailReturn :: Val -> Gen ()
genTailReturn v = commitFinalFuncBlock . ret =<< getLocal v
genTailExpr = genExpr

genExpr :: Expr -> Gen Val
genExpr :: TailVal v => Expr -> Gen v
genExpr (Expr pos expr) = locally srcPos (pos <|>) $ do
    parent <- use lambdaParentFunc <* assign lambdaParentFunc Nothing
    case expr of
        Lit c -> genConst c
        Var x -> lookupVar x
        App f e _ -> genApp f e
        If p c a -> genIf p c a
        Fun (p, b) -> assign lambdaParentFunc parent *> genExprLambda p b
        Lit c -> propagate =<< genConst c
        Var x -> propagate =<< lookupVar x
        App f e _ -> genBetaReduceApp (f, [e])
        If p c a -> genExpr p >>= \p' -> genCondBr p' (genExpr c) (genExpr a)
        Fun (p, b) -> assign lambdaParentFunc parent *> genExprLambda p b >>= propagate
        Let d b -> genLet d b
        Match e cs tbody -> genMatch e cs =<< genType tbody
        Ction c -> genCtion c
        Sizeof t -> (VLocal . litI64 . fromIntegral) <$> ((lift . sizeof) =<< genType t)
        Absurd t -> fmap (VLocal . undef) (genType t)
        Match e cs _ -> genMatch e cs
        Ction c -> propagate =<< genCtion c
        Sizeof t ->
            propagate
                =<< (VLocal . litI64 . fromIntegral)
                <$> ((lift . sizeof) =<< genType t)
        Absurd t -> propagate =<< fmap (VLocal . undef) (genType t)

genExprLambda :: TypedVar -> (Expr, Ast.Type) -> Gen Val
genExprLambda p (b, bt) = do


@@ 305,76 296,60 @@ genStrLit s = do
    scribe outStrings [(var, s)]
    pure $ VVar $ ConstantOperand (LLConst.GlobalReference (LLType.ptr typeStr) var)

genTailApp :: Expr -> Expr -> Gen ()
genTailApp fe' ae' = genBetaReduceApp genTailExpr genTailReturn Tail (fe', [ae'])

genApp :: Expr -> Expr -> Gen Val
genApp fe' ae' = genBetaReduceApp genExpr pure NoTail (fe', [ae'])
class TailVal v where
    propagate :: Val -> Gen v
    app' :: Val -> Val -> Gen v
    converge :: Gen v -> [(Name, Gen v)] -> Gen v

instance TailVal Val where
    propagate = pure
    app' = app (Just NoTail)
    converge default' cs = do
        nextL <- newName "next"
        v <- liftA2 (,) (getLocal =<< default') (use currentBlockLabel)
        let genCase (l, mv) = do
                commitToNewBlock (br nextL) l
                liftA2 (,) (getLocal =<< mv) (use currentBlockLabel)
        vs <- mapM genCase cs
        commitToNewBlock (br nextL) nextL
        fmap VLocal (emitAnonReg (phi (v : vs)))

instance TailVal () where
    propagate v = commitFinalFuncBlock . ret =<< getLocal v
    app' f e = propagate =<< app (Just Tail) f e
    converge default' cs = do
        () <- default'
        forM_ cs $ \(l, gen) -> assign currentBlockLabel l *> gen

-- | Beta-reduction and closure application
genBetaReduceApp
    :: (Expr -> Gen a) -> (Val -> Gen a) -> TailCallKind -> (Expr, [Expr]) -> Gen a
genBetaReduceApp genExpr' returnMethod tail' applic = ask >>= \env -> case applic of
genBetaReduceApp :: TailVal v => (Expr, [Expr]) -> Gen v
genBetaReduceApp applic = ask >>= \env -> case applic of
    (Expr _ (Fun (p, (b, _))), ae : aes) -> do
        a <- genExpr ae
        withVal p a (genBetaReduceApp genExpr' returnMethod tail' (b, aes))
    (Expr _ (App fe ae _), aes) ->
        genBetaReduceApp genExpr' returnMethod tail' (fe, ae : aes)
    (fe, []) -> genExpr' fe
        withVal p a (genBetaReduceApp (b, aes))
    (Expr _ (App fe ae _), aes) -> genBetaReduceApp (fe, ae : aes)
    (fe, []) -> genExpr fe
    (Expr _ (Var x), aes) | isNothing (lookupVar' x env) ->
        returnMethod =<< genAppBuiltinVirtual x (map genExpr aes)
        propagate =<< genAppBuiltinVirtual x (map genExpr aes)
    (fe, aes) -> do
        f <- genExpr fe
        as <- mapM genExpr (init aes)
        closure <- foldlM (app (Just NoTail)) f as
        closure <- foldlM app' f as
        arg <- genExpr (last aes)
        returnMethod =<< app (Just tail') closure arg

genTailIf :: Expr -> Expr -> Expr -> Gen ()
genTailIf pred' conseq alt = do
    predV <- genExpr pred'
    genTailCondBr predV (genTailExpr conseq) (genTailExpr alt)

genIf :: Expr -> Expr -> Expr -> Gen Val
genIf pred' conseq alt = do
    predV <- genExpr pred'
    genCondBr predV (genExpr conseq) (genExpr alt)
        app' closure arg

genTailCondBr :: Val -> Gen () -> Gen () -> Gen ()
genTailCondBr predV genConseq genAlt = do
    predV' <- emitAnonReg . flip trunc i1 =<< getLocal predV
    conseqL <- newName "consequent"
    altL <- newName "alternative"
    commitToNewBlock (condbr predV' conseqL altL) conseqL
    genConseq
    assign currentBlockLabel altL
    genAlt

genCondBr :: Val -> Gen Val -> Gen Val -> Gen Val
genCondBr :: TailVal v => Val -> Gen v -> Gen v -> Gen v
genCondBr predV genConseq genAlt = do
    predV' <- emitAnonReg . flip trunc i1 =<< getLocal predV
    conseqL <- newName "consequent"
    altL <- newName "alternative"
    nextL <- newName "next"
    commitToNewBlock (condbr predV' conseqL altL) conseqL
    conseqV <- getLocal =<< genConseq
    fromConseqL <- use currentBlockLabel
    commitToNewBlock (br nextL) altL
    altV <- getLocal =<< genAlt
    fromAltL <- use currentBlockLabel
    commitToNewBlock (br nextL) nextL
    fmap VLocal (emitAnonReg (phi [(conseqV, fromConseqL), (altV, fromAltL)]))

genTailLet :: Def -> Expr -> Gen ()
genTailLet d = genLet' d . genTailExpr
    converge genConseq [(altL, genAlt)]

genLet :: Def -> Expr -> Gen Val
genLet d = genLet' d . genExpr

genLet' :: Def -> Gen a -> Gen a
genLet' def genBody = case def of
genLet :: TailVal v => Def -> Expr -> Gen v
genLet def body = case def of
    VarDef (lhs, WithPos pos (_, rhs)) ->
        genExpr (Expr (Just pos) rhs) >>= \rhs' -> withVal lhs rhs' genBody
        genExpr (Expr (Just pos) rhs) >>= \rhs' -> withVal lhs rhs' (genExpr body)
    RecDefs ds -> do
        (binds, cs) <- fmap unzip $ forM ds $ \case
            (lhs, WithPos _ (_, (p, (fb, fbt)))) -> do


@@ 388,88 363,45 @@ genLet' def genBody = case def of
                pure ((lhs, lam), (captures, fvXs))
        withVals binds $ do
            forM_ cs (uncurry populateCaptures)
            genBody

genTailMatch :: Expr -> DecisionTree -> Type -> Gen ()
genTailMatch m dt tbody = do
    m' <- genExpr m
    genTailDecisionTree tbody dt (newSelections m')
            (genExpr body)

genMatch :: Expr -> DecisionTree -> Type -> Gen Val
genMatch m dt tbody = do
genMatch :: TailVal v => Expr -> DecisionTree -> Gen v
genMatch m dt = do
    m' <- genExpr m
    genDecisionTree tbody dt (newSelections m')

genTailDecisionTree :: Type -> DecisionTree -> Selections Val -> Gen ()
genTailDecisionTree = genDecisionTree' genTailExpr genTailCondBr genTailCases

genDecisionTree :: Type -> DecisionTree -> Selections Val -> Gen Val
genDecisionTree = genDecisionTree' genExpr genCondBr genCases

genDecisionTree'
    :: (Expr -> Gen a)
    -> (Val -> Gen a -> Gen a -> Gen a)
    -> (Type -> Selections Val -> [Name] -> [DecisionTree] -> DecisionTree -> Gen a)
    -> Type
    -> DecisionTree
    -> Selections Val
    -> Gen a
genDecisionTree' genExpr' genCondBr' genCases' tbody =
    let genDecisionLeaf (bs, e) selections = do
            bs' <- selectVarBindings selections bs
            withVals bs' (genExpr' e)
    genDecisionTree dt (newSelections m')

        genDecisionSwitchIx selector cs def selections = do
            let (variantIxs, variantDts) = unzip (Map.toAscList cs)
            (m, selections') <- select selector selections
            mVariantIx <- getLocal =<< case typeOf m of
                IntegerType _ -> pure m
                _ -> genIndexStruct m [0]
            let ixBits = getIntBitWidth (typeOf mVariantIx)
            let litIxInt = LLConst.Int ixBits
            variantLs <- mapM (newName . (++ "_") . ("variant_" ++) . show) variantIxs
            defaultL <- newName "default"
            let dests' = zip (map litIxInt variantIxs) variantLs
            commitToNewBlock (switch mVariantIx defaultL dests') defaultL
            genCases' tbody selections' variantLs variantDts def

        genDecisionSwitchStr selector cs def selections = do
            (matchee, selections') <- select selector selections
            let cs' = Map.toAscList cs
            let genCase (s, dt) next = do
                    s' <- genStrLit s
                    isMatch <- genStrEq matchee s'
                    -- Do some wrapping to preserve effect order
                    pure $ genCondBr' isMatch (genDT dt selections') next
            join (foldrM genCase (genDT def selections') cs')

        genDT = \case
            Ast.DLeaf l -> genDecisionLeaf l
            Ast.DSwitch selector cs def -> genDecisionSwitchIx selector cs def
            Ast.DSwitchStr selector cs def -> genDecisionSwitchStr selector cs def
    in  genDT

genTailCases
    :: Type -> Selections Val -> [Name] -> [DecisionTree] -> DecisionTree -> Gen ()
genTailCases tbody selections variantLs variantDts def = do
    genTailDecisionTree tbody def selections
    forM_ (zip variantLs variantDts) $ \(l, dt) -> do
        assign currentBlockLabel l
        genTailDecisionTree tbody dt selections

genCases :: Type -> Selections Val -> [Name] -> [DecisionTree] -> DecisionTree -> Gen Val
genCases tbody selections variantLs variantDts def = do
    nextL <- newName "next"
    let genDT dt = liftA2 (,)
                          (getLocal =<< genDecisionTree tbody dt selections)
                          (use currentBlockLabel)
    v <- genDT def
    let genCase l dt = do
            commitToNewBlock (br nextL) l
            genDT dt
    vs <- zipWithM genCase variantLs variantDts
    commitToNewBlock (br nextL) nextL
    fmap VLocal (emitAnonReg (phi (v : vs)))
genDecisionTree :: TailVal v => DecisionTree -> Selections Val -> Gen v
genDecisionTree = \case
    Ast.DLeaf l -> genDecisionLeaf l
    Ast.DSwitch selector cs def -> genDecisionSwitchIx selector cs def
    Ast.DSwitchStr selector cs def -> genDecisionSwitchStr selector cs def
  where
    genDecisionLeaf (bs, e) selections = do
        bs' <- selectVarBindings selections bs
        withVals bs' (genExpr e)
    genDecisionSwitchIx selector cs def selections = do
        let (variantIxs, variantDts) = unzip (Map.toAscList cs)
        (m, selections') <- select selector selections
        mVariantIx <- getLocal =<< case typeOf m of
            IntegerType _ -> pure m
            _ -> genIndexStruct m [0]
        let ixBits = getIntBitWidth (typeOf mVariantIx)
        let litIxInt = LLConst.Int ixBits
        variantLs <- mapM (newName . (++ "_") . ("variant_" ++) . show) variantIxs
        defaultL <- newName "default"
        let dests' = zip (map litIxInt variantIxs) variantLs
        commitToNewBlock (switch mVariantIx defaultL dests') defaultL
        converge (genDecisionTree def selections')
                 (zip variantLs (map (flip genDecisionTree selections') variantDts))
    genDecisionSwitchStr selector cs def selections = do
        (matchee, selections') <- select selector selections
        let cs' = Map.toAscList cs
        let genCase (s, dt) next = do
                s' <- genStrLit s
                isMatch <- genStrEq matchee s'
                -- Do some wrapping to preserve effect order
                pure $ genCondBr isMatch (genDecisionTree dt selections') next
        join (foldrM genCase (genDecisionTree def selections') cs')

genCtion :: Ast.Ction -> Gen Val
genCtion (i, span', dataType, as) = do