~jojo/Carth

ca7366cf3ffa0fe146fcc5c3899ac7f83e34faf6 — JoJo 1 year, 10 months ago b32b4f4
Make passByRef more exhaustive and correct

I think it's pretty close to System V ABI spec now.

Sadly, access to the environment of data types is needed to look up a
NamedTypeReference, and this set off something of a chain reaction. Not
only did passByRef become Gen', but so did toLlvmType! This had some
effect throughout the module.
1 files changed, 136 insertions(+), 85 deletions(-)

M src/Codegen.hs
M src/Codegen.hs => src/Codegen.hs +136 -85
@@ 58,6 58,7 @@ import Data.Word
import Data.Foldable
import Data.List
import Data.Composition
import Data.Functor
import Control.Applicative
import Control.Lens
    ( makeLenses


@@ 93,6 94,7 @@ data Val

data Env = Env
    { _env :: Map TypedVar Operand  -- ^ Environment of stack allocated variables
    , _dataTypes :: Map Name Type
    , _dataLayout :: DataLayout
    }
makeLenses ''Env


@@ 142,20 144,49 @@ instance Typed Val where

codegen :: DataLayout -> FilePath -> Program -> EncodeAST Module
codegen layout moduleFilePath (Program defs tdefs externs) = do
    tdefs' <- defineDataTypes layout tdefs
    let defs' = Map.toList defs
        genGlobDefs = withExternSigs externs $ withGlobDefSigs
            defs'
            (liftA2 (:) genMain (fmap join (mapM genGlobDef defs')))
    globDefs <- runGen' layout genGlobDefs
    (tdefs', externs', globDefs) <- runGen' layout $ do
        tdefs'' <- defineDataTypes tdefs
        withDataTypes tdefs''
            $ withExternSigs externs
            $ withGlobDefSigs defs'
            $ do
                es <- genExterns externs
                ds <- liftA2 (:) genMain (fmap join (mapM genGlobDef defs'))
                pure (tdefs'', es, ds)
    pure Module
        { moduleName = fromString ((takeBaseName moduleFilePath))
        , moduleSourceFileName = fromString moduleFilePath
        , moduleDataLayout = Just layout
        , moduleTargetTriple = Nothing
        , moduleDefinitions =
            tdefs' ++ genBuiltins ++ genExterns externs ++ globDefs
        , moduleDefinitions = concat
            [ map (\(n, tmax) -> TypeDefinition n (Just tmax)) tdefs'
            , genBuiltins
            , externs'
            , globDefs
            ]
        }
  where
    withDataTypes = augment dataTypes . Map.fromList
    withExternSigs es ga = do
        es' <- forM es $ \(name, t) -> do
            t' <- toLlvmType' t
            pure
                ( TypedVar name t
                , ConstantOperand
                    $ LLConst.GlobalReference (LLType.ptr t') (mkName name)
                )
        augment env (Map.fromList es') ga
    withGlobDefSigs sigs ga = do
        sigs' <- forM sigs $ \(v@(TypedVar x t), (us, _)) -> do
            t' <- toLlvmType' t
            pure
                ( v
                , ConstantOperand $ LLConst.GlobalReference
                    (LLType.ptr t')
                    (mkName (mangleName (x, us)))
                )
        augment env (Map.fromList sigs') ga

-- TODO: Consider separating sizeof calculations to a separate pass preceding
--       Codegen, so that IO/EncodeAST may be limited to a more overviewable and


@@ 174,25 205,25 @@ codegen layout moduleFilePath (Program defs tdefs externs) = do
--   A data-type is a tagged union, and is represented in LLVM as a struct where
--   the first element is the variant-index as an i64, and the rest of the
--   elements are the field-types of the largest variant wrt allocation size.
defineDataTypes :: DataLayout -> TypeDefs -> EncodeAST [Definition]
defineDataTypes layout tds = do
defineDataTypes :: TypeDefs -> Gen' [(Name, Type)]
defineDataTypes tds = do
    -- Forward declare to allow for recursion and unordered defs
    lhss <- forM tds $ \(tc, _) -> do
        let n = mkName (mangleTConst tc)
        (lhs, n') <- createNamedType n
        defineType n n' lhs
        (lhs, n') <- lift (lift (createNamedType n))
        lift (lift (defineType n n' lhs))
        pure (n, lhs)
    forM (zip lhss tds) $ \((n, lhs), (_, vs)) -> do
        let ts = map toLlvmVariantType vs
        sizedTs <- mapM (\t -> fmap (, t) (sizeof layout t)) ts
        ts <- mapM toLlvmVariantType vs
        sizedTs <- mapM (\t -> fmap (, t) (sizeof' t)) ts
        let (_, tmax) = maximum sizedTs
        setNamedType lhs tmax
        pure (TypeDefinition n (Just tmax))
        lift (lift (setNamedType lhs tmax))
        pure (n, tmax)

runGen' :: DataLayout -> Gen' a -> EncodeAST a
runGen' layout g = runReaderT
    (evalStateT g initSt)
    Env { _env = Map.empty, _dataLayout = layout }
    Env { _env = Map.empty, _dataTypes = Map.empty, _dataLayout = layout }

initSt :: St
initSt = St


@@ 212,12 243,12 @@ genBuiltins = map
        (LLType.ptr typeUnit)
    ]

genExterns :: [(String, MonoAst.Type)] -> [Definition]
genExterns = map (uncurry genExtern)
-- | Generate one LLVM definition per extern declaration, in input order.
--   Each @(name, type)@ pair is handed to 'genExtern'.
genExterns :: [(String, MonoAst.Type)] -> Gen' [Definition]
genExterns externs = forM externs $ \(name, t) -> genExtern name t

genExtern :: String -> MonoAst.Type -> Definition
genExtern name t =
    GlobalDefinition $ simpleGlobVar' (mkName name) (toLlvmType t) Nothing
-- | Generate the global definition for a single extern: a global variable
--   named after the extern, of the converted LLVM type, with no initializer
--   (the final @Nothing@ argument to 'simpleGlobVar'').
genExtern :: String -> MonoAst.Type -> Gen' Definition
genExtern name t = do
    t' <- toLlvmType' t
    pure (GlobalDefinition (simpleGlobVar' (mkName name) t' Nothing))

genMain :: Gen' Definition
genMain = do


@@ 266,10 297,11 @@ genFunDef (name, fvs, ptv@(TypedVar px pt), body) = do
    assign currentBlockInstrs []
    ((rt, fParams), Out basicBlocks globStrings lambdaFuncs) <- runWriterT $ do
        (capturesParam, captureLocals) <- genExtractCaptures fvs
        let pt' = toLlvmType pt
        pt' <- toLlvmType pt
        px' <- newName px
        -- Load params according to calling convention
        let (withParam, pt'') = if passByRef pt'
        passParamByRef <- passByRef pt'
        let (withParam, pt'') = if passParamByRef
                then (withVar, LLType.ptr pt')
                else (withLocal, pt')
        let pRef = LocalReference pt'' px'


@@ 279,7 311,8 @@ genFunDef (name, fvs, ptv@(TypedVar px pt), body) = do
        let fParams' =
                [uncurry Parameter capturesParam [], Parameter pt'' px' []]
        -- Return result according to calling convention
        if passByRef rt'
        returnResultByRef <- passByRef rt'
        if returnResultByRef
            then do
                let out = (LLType.ptr rt', mkName "out")
                emit (store result (uncurry LocalReference out))


@@ 302,7 335,7 @@ genExtractCaptures fvs = do
    fmap (capturesParam, ) $ if null fvs
        then pure []
        else do
            let capturesType = typeCaptures fvs
            capturesType <- typeCaptures fvs
            capturesPtr <- emitAnon
                (bitcast capturesPtrGeneric (LLType.ptr capturesType))
            captures <- emitAnon (load capturesPtr)


@@ 321,7 354,7 @@ genExpr expr = do
        If p c a -> genIf p c a
        Fun p b -> assign lambdaParentFunc parent *> genLambda p b
        Let ds b -> genLet ds b
        Match e cs tbody -> genMatch e cs (toLlvmType tbody)
        Match e cs tbody -> genMatch e cs =<< toLlvmType tbody
        -- TODO: Currently, the desugar converts a constructor to a construction
        --       wrapped in a bunch of lambdas. This generates a lot of wasteful
        --       code. We could be smarter -- keep it as a constructor item until


@@ 338,13 371,16 @@ genExpr expr = do
-- | The LLVM type used to refer to a data-type: a named type reference whose
--   name is the mangled type constant.
toLlvmDataType :: MonoAst.TConst -> Type
toLlvmDataType tc = typeNamed (mangleTConst tc)

toLlvmVariantType :: [MonoAst.Type] -> Type
toLlvmVariantType = typeStruct . (i64 :) . map toLlvmType
-- | The LLVM struct type of a single variant: a leading @i64@ slot followed by
--   the converted types of the variant's fields, in order.
toLlvmVariantType :: [MonoAst.Type] -> Gen' Type
toLlvmVariantType fields = do
    fields' <- mapM toLlvmType' fields
    pure (typeStruct (i64 : fields'))

-- | 'toLlvmType'' lifted into the 'Gen' monad.
toLlvmType :: MonoAst.Type -> Gen Type
toLlvmType t = lift (toLlvmType' t)

-- | Convert to the LLVM representation of a type in an expression-context.
toLlvmType :: MonoAst.Type -> Type
toLlvmType = \case
    TPrim tc -> case tc of
toLlvmType' :: MonoAst.Type -> Gen' Type
toLlvmType' = \case
    TPrim tc -> pure $ case tc of
        TUnit -> typeUnit
        TNat8 -> i8
        TNat16 -> i16


@@ 358,8 394,8 @@ toLlvmType = \case
        TChar -> i32
        TBool -> typeBool
    TFun a r -> toLlvmClosureType a r
    TBox t -> LLType.ptr (toLlvmType t)
    TConst t -> typeNamed (mangleTConst t)
    TBox t -> fmap LLType.ptr (toLlvmType' t)
    TConst t -> pure $ typeNamed (mangleTConst t)

-- | A `Fun` is a closure, and follows a certain calling convention
--


@@ 370,25 406,26 @@ toLlvmType = \case
--
--   An argument of a structure-type is passed by reference, to be compatible
--   with the C calling convention.
toLlvmClosureType :: MonoAst.Type -> MonoAst.Type -> Type
toLlvmClosureType a r =
    typeStruct [LLType.ptr typeUnit, LLType.ptr (toLlvmClosureFunType a r)]
-- | The LLVM type of a closure value: a two-element struct holding a pointer
--   to 'typeUnit' and a pointer to the closure's function type as computed by
--   'toLlvmClosureFunType'.
toLlvmClosureType :: MonoAst.Type -> MonoAst.Type -> Gen' Type
toLlvmClosureType a r = do
    f <- toLlvmClosureFunType a r
    pure (typeStruct [LLType.ptr typeUnit, LLType.ptr f])

-- The type of the function itself within the closure
toLlvmClosureFunType :: MonoAst.Type -> MonoAst.Type -> Type
toLlvmClosureFunType a r =
    let
        a' = toLlvmType a
        a'' = if passByRef a' then LLType.ptr a' else a'
        r' = toLlvmType r
    in if passByRef r'
toLlvmClosureFunType :: MonoAst.Type -> MonoAst.Type -> Gen' Type
toLlvmClosureFunType a r = do
    a' <- toLlvmType' a
    r' <- toLlvmType' r
    passArgByRef <- passByRef' a'
    let a'' = if passArgByRef then LLType.ptr a' else a'
    returnResultByRef <- passByRef' r'
    pure $ if returnResultByRef
        then FunctionType
            { resultType = LLType.void
            , argumentTypes = [LLType.ptr r', LLType.ptr typeUnit, a'']
            , isVarArg = False
            }
        else FunctionType
            { resultType = toLlvmType r
            { resultType = r'
            , argumentTypes = [LLType.ptr typeUnit, a'']
            , isVarArg = False
            }


@@ 435,7 472,7 @@ genApp fe ae rt = genApp' (fe, [(ae, rt)])
        (fe, aes) -> do
            closure <- genExpr fe
            as <- mapM
                (\(ae, rt) -> fmap (, toLlvmType rt) (genExpr ae))
                (\(ae, rt) -> liftA2 (,) (genExpr ae) (toLlvmType rt))
                aes
            foldlM (\f (a, rt) -> app f a rt) closure as



@@ 444,9 481,11 @@ app closure a rt = do
    closure' <- getLocal closure
    captures <- emitReg' "captures" (extractvalue closure' [0])
    f <- emitReg' "function" (extractvalue closure' [1])
    a' <- if passByRef (typeOf a) then getVar a else getLocal a
    passArgByRef <- passByRef (typeOf a)
    a' <- if passArgByRef then getVar a else getLocal a
    let args = [(captures, []), (a', [])]
    if passByRef rt
    returnByRef <- passByRef rt
    if returnByRef
        then do
            out <- emitReg' "out" (alloca rt)
            emit'' $ call f ((out, [SRet]) : args)


@@ 488,7 527,9 @@ genIf pred conseq alt = do
genLet :: Defs -> Expr -> Gen Val
genLet ds b = do
    let (vs, es) = unzip (Map.toList ds)
    ps <- mapM (\(TypedVar n t) -> emitReg' n (alloca (toLlvmType t))) vs
    ps <- forM vs $ \(TypedVar n t) -> do
        t' <- toLlvmType t
        emitReg' n (alloca t')
    withVars (zip vs ps) $ do
        forM_ (zip ps es) $ \(p, (_, e)) -> do
            x <- getLocal =<< genExpr e


@@ 535,7 576,7 @@ genDecisionLeaf (bs, e) selections =

genAs :: [MonoAst.Type] -> Operand -> Gen Operand
genAs ts matchee = do
    let tvariant = toLlvmVariantType ts
    tvariant <- lift (toLlvmVariantType ts)
    let tgeneric = typeOf matchee
    pGeneric <- emitReg' "ction_ptr_generic" (alloca tgeneric)
    emit (store matchee pGeneric)


@@ 575,7 616,8 @@ genLambda p@(TypedVar px pt) (b, bt) = do
        Just s ->
            fmap (mkName . ((s ++ "_func_") ++) . show) (outerLambdaN <<+= 1)
        Nothing -> newName "func"
    let ft = toLlvmClosureFunType pt bt
    ft <- lift (toLlvmClosureFunType pt bt)
    let
        f = VLocal $ ConstantOperand $ LLConst.GlobalReference
            (LLType.ptr ft)
            fname


@@ 608,12 650,8 @@ genBox' x = do

genHeapAlloc :: Type -> Gen Operand
genHeapAlloc t = do
    size <- genSizeof t
    size <- fmap litU64' (lift (sizeof' t))
    emitAnon (callExtern "carth_alloc" (LLType.ptr typeUnit) [size])
  where
    genSizeof t = do
        layout <- view dataLayout
        fmap litU64' (lift (lift (lift (sizeof layout t))))

genDeref :: Expr -> Gen Val
genDeref e = genExpr e >>= \case


@@ 685,28 723,6 @@ getLocal = \case
    VVar x -> emitAnon (load x)
    VLocal x -> pure x

withExternSigs :: MonadReader Env m => [(String, MonoAst.Type)] -> m a -> m a
withExternSigs = augment env . Map.fromList . map
    (\(name, t) ->
        ( TypedVar name t
        , ConstantOperand
            (LLConst.GlobalReference (LLType.ptr (toLlvmType t)) (mkName name))
        )
    )

withGlobDefSigs
    :: MonadReader Env m => [(TypedVar, ([MonoAst.Type], Expr))] -> m a -> m a
withGlobDefSigs = augment env . Map.fromList . map
    (\(v@(TypedVar x t), (us, _)) ->
        ( v
        , ConstantOperand
            (LLConst.GlobalReference
                (LLType.ptr (toLlvmType t))
                (mkName (mangleName (x, us)))
            )
        )
    )

withLocals :: [(TypedVar, Operand)] -> Gen a -> Gen a
withLocals = flip (foldr (uncurry withLocal))



@@ 911,15 927,45 @@ litStructOfType t xs =
-- | The unit value as an LLVM constant, built by 'litStruct' from an empty
--   list of fields.
litUnit :: LLConst.Constant
litUnit = litStruct []

passByRef :: Type -> Bool
passByRef = \case
    LLType.StructureType _ [] -> False
    LLType.StructureType _ _ -> True
    LLType.NamedTypeReference _ -> True
    _ -> False
-- | 'passByRef'' lifted into the 'Gen' monad.
passByRef :: Type -> Gen Bool
passByRef t = lift (passByRef' t)

typeCaptures :: [TypedVar] -> Type
typeCaptures = typeStruct . map (\(TypedVar _ t) -> toLlvmType t)
-- TODO: Handle >64bit integers and pointers. Haven't checked the rules for
--       those.
--
-- NOTE: This post is helpful:
--       https://stackoverflow.com/questions/42411819/c-on-x86-64-when-are-structs-classes-passed-and-returned-in-registers
--       Also, official docs:
--       https://software.intel.com/sites/default/files/article/402129/mpx-linux64-abi.pdf
--       particularly section 3.2.3 Parameter Passing (p18).

-- | Decide whether a value of the given LLVM type is passed by reference
--   (@True@) or by value (@False@) under the calling convention.
--
--   Named types are resolved through the 'dataTypes' environment and then
--   classified by their definition. Aggregates larger than 16 bytes are always
--   passed by reference; smaller ones are passed by reference only if one of
--   their members is.
--
--   Raises an internal compiler error for a named type that is missing from
--   the environment, and for metadata, label and token types.
passByRef' :: Type -> Gen' Bool
passByRef' t = case t of
    -- Use a total lookup so that a missing definition is reported with the
    -- offending name instead of an anonymous `Map.!` failure.
    NamedTypeReference x -> views dataTypes (Map.lookup x) >>= \case
        Just u -> passByRef' u
        Nothing -> ice ("passByRef of undefined named type " ++ show x)
    -- Simple scalar types. They go in registers.
    VoidType -> pure False
    IntegerType _ -> pure False
    PointerType _ _ -> pure False
    FloatingPointType _ -> pure False
    -- Functions are not POD (Plain Ol' Data), so they are passed on the stack.
    FunctionType _ _ _ -> pure True
    -- TODO: Investigate how exactly SIMD vectors are to be passed when/if we
    --       ever add support for that in the rest of the compiler.
    VectorType _ _ -> pure True
    -- Aggregate types can either be passed on stack or in regs, depending on
    -- what they contain.
    StructureType _ us -> do
        size <- sizeof' t
        if size > 16 then pure True else fmap or (mapM passByRef' us)
    ArrayType _ u -> do
        -- Measure the whole array `t`, not a single element `u`: an array of
        -- many small elements must still count as a large aggregate.
        size <- sizeof' t
        if size > 16 then pure True else passByRef' u
    -- N/A
    MetadataType -> ice "passByRef of MetadataType"
    LabelType -> ice "passByRef of LabelType"
    TokenType -> ice "passByRef of TokenType"

-- | The LLVM struct type holding a closure's captured variables: one field per
--   variable, in order, each of the variable's converted LLVM type.
typeCaptures :: [TypedVar] -> Gen Type
typeCaptures fvs = do
    ts <- forM fvs $ \(TypedVar _ t) -> toLlvmType t
    pure (typeStruct ts)

-- | A reference to the named LLVM type with the given name.
typeNamed :: String -> Type
typeNamed name = NamedTypeReference (mkName name)


@@ 969,6 1015,11 @@ mangleType = \case
-- | Mangle a type constant: the constructor name immediately followed by the
--   mangled form of its type arguments (see 'mangleInst').
mangleTConst :: TConst -> String
mangleTConst (name, args) = concat [name, mangleInst args]

-- | 'sizeof' inside 'Gen'', with the data layout read from the environment
--   instead of being passed explicitly.
sizeof' :: Type -> Gen' Word64
sizeof' t = view dataLayout >>= \layout -> lift (lift (sizeof layout t))

sizeof :: DataLayout -> Type -> EncodeAST Word64
sizeof layout t = do
    t' <- toFFIType t