~boringcactus/crowbar-reference-compiler

b6258d36b6534d521e9cdf1307665c38e5ae409d — Melody Horn 1 year, 6 months ago 9dfc552
compile based on the fancy new AST
M crowbar_reference_compiler/__init__.py => crowbar_reference_compiler/__init__.py +2 -6
@@ 1,10 1,7 @@
import dataclasses
from pprint import pprint

from .ast import build_ast
from .parser import parse_header, parse_implementation
from .scanner import scan
from .ssagen import compile_to_ssa
from .ssagen import build_ssa


def main():


@@ 37,9 34,8 @@ def main():
        return

    full_ast = build_ast(parse_tree, args.include_dir)
    pprint(dataclasses.asdict(full_ast))

    ssa = compile_to_ssa(parse_tree)
    ssa = build_ssa(full_ast)
    if args.stop_at_qbe_ssa:
        if args.out is None:
            args.out = args.input.replace('.cro', '.ssa')

M crowbar_reference_compiler/ast.py => crowbar_reference_compiler/ast.py +2 -1
@@ 530,7 530,8 @@ class ASTBuilder(NodeVisitor):
            return ConstantExpression(body.data)
        if body.type in ['true', 'false']:
            return ConstantExpression(body.type)
        raise NotImplementedError()
        if body.type == 'string_literal':
            return ConstantExpression(body.data)

    def visit_StructPointerElementSuffix(self, node, visited_children):
        separator, element = visited_children

M crowbar_reference_compiler/ssagen.py => crowbar_reference_compiler/ssagen.py +113 -124
@@ 1,124 1,113 @@
from parsimonious import NodeVisitor  # type: ignore
from parsimonious.nodes import Node  # type: ignore


class SsaGenVisitor(NodeVisitor):
    """Parse-tree visitor that renders an implementation file as QBE SSA text.

    Data definitions for string literals are collected in ``self.data`` while
    function bodies are visited and are emitted ahead of the functions.
    Constructs that are not supported yet raise ``NotImplementedError``
    instead of silently producing ``None``.
    """

    def __init__(self):
        # Data definition lines, appended as string literals are encountered.
        self.data = []

    def visit_ImplementationFile(self, node, visited_children):
        """Join the collected data definitions and the compiled functions."""
        data = '\n'.join(self.data)
        functions = '\n'.join(visited_children)
        return data + '\n' + functions

    def visit_IncludeStatement(self, node, visited_children):
        """Validate the shape of an include statement; it emits no SSA."""
        include, included_header, semicolon = visited_children
        assert include.text[0].type == 'include'
        assert included_header.type == 'string_literal'
        assert semicolon.text[0].type == ';'
        return ''

    def visit_FunctionDefinition(self, node, visited_children):
        """Wrap the body instructions in an exported function definition.

        The return type and arguments from the signature are currently
        ignored: the function is always emitted as returning ``w`` and
        taking no parameters.
        """
        signature, body = visited_children
        return_type, name, args = signature
        body = '\n'.join('    ' + instr for instr in body)
        return f"export function w ${name}() {{\n@start\n{body}\n}}"

    def visit_FunctionSignature(self, node, visited_children):
        """Return ``(return_type, name, args)`` with ``name`` as plain data."""
        return_type, name, lparen, args, rparen = visited_children
        assert name.type == 'identifier'
        name = name.data
        assert lparen.text[0].type == '('
        assert rparen.text[0].type == ')'
        return return_type, name, args

    def visit_Block(self, node, visited_children):
        """Return the statements between the braces."""
        lbrace, statements, rbrace = visited_children
        return statements

    def visit_Statement(self, node, visited_children):
        """Pass through the single concrete statement."""
        return visited_children[0]

    def visit_ExpressionStatement(self, node, visited_children):
        """Return the compiled expression, dropping the semicolon."""
        expression, semicolon = visited_children
        assert semicolon.text[0].type == ';'
        return expression

    def visit_Expression(self, node, visited_children):
        # TODO handle logical and/or
        return visited_children[0]

    def visit_ComparisonExpression(self, node, visited_children):
        # TODO handle comparisons
        return visited_children[0]

    def visit_BitwiseOpExpression(self, node, visited_children):
        # TODO handle bitwise operations
        return visited_children[0]

    def visit_ArithmeticExpression(self, node, visited_children):
        # TODO handle addition/subtraction
        return visited_children[0]

    def visit_TermExpression(self, node, visited_children):
        # TODO handle multiplication/division/modulus
        return visited_children[0]

    def visit_FactorExpression(self, node, visited_children):
        # TODO handle casts/address-of/pointer-dereference/unary ops/sizeof
        return visited_children[0]

    def visit_ObjectExpression(self, node, visited_children):
        """Compile a bare object or a call on an identifier.

        Raises:
            NotImplementedError: for any suffixed expression other than a
                function call on an identifier.
        """
        # TODO handle array literals
        # TODO handle struct literals
        base, suffices = visited_children[0]
        if isinstance(suffices, Node):
            suffices = suffices.children
        if len(suffices) == 0:
            return base
        if base.type == 'identifier' and suffices[0].text[0].type == '(':
            arguments = suffices[1]
            if arguments[0].type == 'string_literal':
                data = arguments[0].data
                name = f"$data{len(self.data)}"
                # TODO handle non-variadic functions
                arguments = [f"l {name}", '...']
                self.data.append(f"data {name} = {{ b {data}, b 0 }}")
            return f"call ${base.data}({', '.join(arguments)})"
        # Previously this path printed debug output and returned None, which
        # only failed later when the function body was joined.
        raise NotImplementedError(
            f"unsupported object expression: {base!r} followed by {suffices[0]!r}"
        )

    def visit_AtomicExpression(self, node, visited_children):
        # TODO handle parenthesized subexpressions
        return visited_children[0]

    def visit_FlowControlStatement(self, node, visited_children):
        """Compile ``return <constant>;`` into a ``ret`` instruction.

        Raises:
            NotImplementedError: when the returned value is not a constant.
        """
        # TODO handle break/continue
        ret, arg, semicolon = visited_children[0]
        assert ret.text[0].type == 'return'
        assert semicolon.text[0].type == ';'
        if arg.type == 'constant':
            return f"ret {arg.data}"
        raise NotImplementedError(f"unsupported return value: {arg!r}")

    def visit_constant(self, node, visited_children):
        """Return the first element of the node's text."""
        return node.text[0]

    def visit_string_literal(self, node, visited_children):
        """Return the first element of the node's text."""
        return node.text[0]

    def visit_identifier(self, node, visited_children):
        """Return the first element of the node's text."""
        return node.text[0]

    def generic_visit(self, node, visited_children):
        """ The generic visit method.

        Leaves are returned as-is, single-child nodes collapse to that
        child's result, and anything else yields the list of child results.
        """
        if not visited_children:
            return node
        if len(visited_children) == 1:
            return visited_children[0]
        return visited_children


def compile_to_ssa(parse_tree):
    """Visit *parse_tree* with a fresh ``SsaGenVisitor`` and return the result."""
    return SsaGenVisitor().visit(parse_tree)
from dataclasses import dataclass
from functools import singledispatch
from typing import Dict, List

from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \
    VariableExpression, ConstantExpression, ReturnStatement, BasicType


@dataclass
class SsaResult:
    """SSA text produced by compiling one AST node."""
    # Data definition lines (such as string-literal ``data $dataN`` entries).
    data: List[str]
    # Instruction lines, in the order they should be emitted.
    code: List[str]


@dataclass
class CompileContext:
    """Mutable counters shared across one compilation."""
    # Index used to name the next ``$dataN`` definition.
    next_data: int = 0
    # Index used to name the next ``%tN`` temporary.
    next_temp: int = 0


def build_ssa(file: ImplementationFile) -> str:
    """Compile *file* and render it as text.

    The output is the data definitions, a blank line, then the code; each
    section has one entry per line.
    """
    compiled = compile_to_ssa(file, CompileContext())
    sections = ('\n'.join(compiled.data), '\n'.join(compiled.code))
    return '\n\n'.join(sections)


@singledispatch
def compile_to_ssa(target, context: CompileContext) -> SsaResult:
    """Fallback used when no compiler is registered for ``type(target)``.

    Raises:
        NotImplementedError: always, naming the unsupported type.
    """
    node_type = type(target)
    raise NotImplementedError(f"unannotated compile on {node_type}")


@compile_to_ssa.register
def _(target: ImplementationFile, context: CompileContext) -> SsaResult:
    """Compile every top-level item of an implementation file.

    The data and code lines of each item in ``target.contents`` are
    concatenated in order into a single result.
    """
    data = []
    code = []
    # A distinct loop name avoids rebinding the ``target`` parameter mid-loop.
    for item in target.contents:
        result = compile_to_ssa(item, context)
        data += result.data
        code += result.code
    return SsaResult(data, code)


@compile_to_ssa.register
def _(target: FunctionDefinition, context: CompileContext) -> SsaResult:
    """Compile a function definition into an exported function block.

    Only functions with no arguments and an ``int32`` return type are
    accepted; the assertions run after the body has been compiled.
    """
    data_lines = []
    body_lines = []
    for statement in target.body:
        compiled = compile_to_ssa(statement, context)
        data_lines.extend(compiled.data)
        body_lines.extend(compiled.code)
    # Instructions are indented one level beneath the ``@start`` label.
    indented = ['    ' + instr for instr in body_lines]
    assert len(target.args) == 0
    assert target.return_type == BasicType('int32')
    header = f"export function w ${target.name}() {{"
    return SsaResult(data_lines, [header, "@start", *indented, "}"])


@compile_to_ssa.register
def _(target: ExpressionStatement, context: CompileContext) -> SsaResult:
    """Compile an expression statement by compiling the expression it wraps."""
    expression = target.body
    return compile_to_ssa(expression, context)


@compile_to_ssa.register
def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult:
    """Compile a call to a named function.

    Each argument is compiled first, then a ``call`` instruction is appended.
    Every argument is passed with type ``l`` and the call is always emitted
    with a trailing ``...`` marker. A call with no arguments is emitted as
    ``call $name(...)`` rather than the malformed ``call $name(, ...)``.
    """
    assert isinstance(target.function, VariableExpression)
    data = []
    code = []
    args = []
    for expr in target.arguments:
        # NOTE(review): this reads the temp index *before* compiling the
        # argument, so it assumes the argument writes its value to the first
        # temp it allocates. That holds for constants; a nested call
        # allocates no result temp of its own -- confirm before extending.
        arg_dest = context.next_temp
        result = compile_to_ssa(expr, context)
        data += result.data
        code += result.code
        args.append(f"l %t{arg_dest}")
    # TODO handle non-variadic functions
    arg_list = ','.join(args) + ', ...' if args else '...'
    code.append(f"call ${target.function.name}({arg_list})")
    return SsaResult(data, code)


@compile_to_ssa.register
def _(target: ConstantExpression, context: CompileContext) -> SsaResult:
    """Compile a constant into a fresh temporary.

    A string literal (value starting with a double quote) becomes a new
    NUL-terminated data definition whose address is copied into an ``l``
    temp. Any other value is copied verbatim into a ``w`` temp, after
    asserting it is not a prefixed, fractional or character literal.
    """
    literal = target.value
    if not literal.startswith('"'):
        # Literal forms that are not handled yet.
        unsupported_prefixes = ('0b', '0B', '0o', '0x', '0X', '0f', '0F', "'")
        assert not literal.startswith(unsupported_prefixes)
        assert '.' not in literal
        temp = context.next_temp
        context.next_temp += 1
        return SsaResult([], [f"%t{temp} =w copy {literal}"])

    data_index = context.next_data
    context.next_data += 1
    temp = context.next_temp
    context.next_temp += 1
    return SsaResult(
        [f"data $data{data_index} = {{ b {literal}, b 0 }}"],
        [f"%t{temp} =l copy $data{data_index}"],
    )


@compile_to_ssa.register
def _(target: ReturnStatement, context: CompileContext) -> SsaResult:
    """Compile a return statement.

    With no value this is a bare ``ret``. Otherwise the value expression is
    compiled and a ``ret`` of the temp index that was next at entry is
    appended to that expression's result.
    """
    if target.body is not None:
        # Read before compiling: the value is expected in this temp.
        value_temp = context.next_temp
        compiled = compile_to_ssa(target.body, context)
        compiled.code.append(f"ret %t{value_temp}")
        return compiled
    return SsaResult([], ['ret'])

M tests/test_hello_world.py => tests/test_hello_world.py +8 -5
@@ 1,12 1,12 @@
import unittest

from crowbar_reference_compiler import compile_to_ssa, parse_header, parse_implementation, scan
from crowbar_reference_compiler import build_ast, build_ssa, parse_header, parse_implementation, scan


class TestHelloWorld(unittest.TestCase):
    def test_ssa(self):
        code = r"""
include "stdio.hro";
//include "stdio.hro";

int32 main() {
    printf("Hello, world!\n");


@@ 15,14 15,17 @@ int32 main() {
"""
        tokens = scan(code)
        parse_tree = parse_implementation(tokens)
        actual_ssa = compile_to_ssa(parse_tree)
        ast = build_ast(parse_tree, [])
        actual_ssa = build_ssa(ast)
        expected_ssa = r"""
data $data0 = { b "Hello, world!\n", b 0 }

export function w $main() {
@start
    call $printf(l $data0, ...)
    ret 0
    %t0 =l copy $data0
    call $printf(l %t0, ...)
    %t1 =w copy 0
    ret %t1
}
""".strip()
        self.assertEqual(expected_ssa, actual_ssa)