~boringcactus/crowbar-reference-compiler

bd933eeef043e0bc5e3ddd6557c9884593b59d3b — Melody Horn 1 year, 6 months ago 67e4837
add some more instructions
1 files changed, 86 insertions(+), 24 deletions(-)

M crowbar_reference_compiler/ssagen.py
M crowbar_reference_compiler/ssagen.py => crowbar_reference_compiler/ssagen.py +86 -24
@@ 1,9 1,9 @@
from dataclasses import dataclass
from functools import singledispatch
from typing import Dict, List
from typing import List

from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \
    VariableExpression, ConstantExpression, ReturnStatement, BasicType
    VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, AddExpression


@dataclass


@@ 11,11 11,23 @@ class SsaResult:
    data: List[str]
    code: List[str]

    def __add__(self, other: 'SsaResult') -> 'SsaResult':
        if not isinstance(other, SsaResult):
            return NotImplemented
        return SsaResult(self.data + other.data, self.code + other.code)

    def __radd__(self, other: 'SsaResult'):
        if not isinstance(other, SsaResult):
            return NotImplemented
        self.data += other.data
        self.code += other.code


@dataclass
class CompileContext:
    next_data: int = 0
    next_temp: int = 0
    next_label: int = 0


def build_ssa(file: ImplementationFile) -> str:


@@ 32,28 44,22 @@ def compile_to_ssa(target, context: CompileContext) -> SsaResult:

@compile_to_ssa.register
def _(target: ImplementationFile, context: CompileContext):
    data = []
    code = []
    result = SsaResult([], [])
    for target in target.contents:
        result = compile_to_ssa(target, context)
        data += result.data
        code += result.code
    return SsaResult(data, code)
        result += compile_to_ssa(target, context)
    return result


@compile_to_ssa.register
def _(target: FunctionDefinition, context: CompileContext) -> SsaResult:
    data = []
    code = []
    result = SsaResult([], [])
    for statement in target.body:
        result = compile_to_ssa(statement, context)
        data += result.data
        code += result.code
    code = ['    ' + instr for instr in code]
        result += compile_to_ssa(statement, context)
    code = ['    ' + instr for instr in result.code]
    assert len(target.args) == 0
    assert target.return_type == BasicType('int32')
    code = [f"export function w ${target.name}() {{", "@start", *code, "}"]
    return SsaResult(data, code)
    return SsaResult(result.data, code)


@compile_to_ssa.register


@@ 64,17 70,14 @@ def _(target: ExpressionStatement, context: CompileContext) -> SsaResult:
@compile_to_ssa.register
def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult:
    assert isinstance(target.function, VariableExpression)
    data = []
    code = []
    result = SsaResult([], [])
    args = []
    for i, expr in enumerate(target.arguments):
        arg_dest = context.next_temp
        result = compile_to_ssa(expr, context)
        data += result.data
        code += result.code
        result += compile_to_ssa(expr, context)
        arg_dest = context.next_temp - 1
        args += [f"l %t{arg_dest}"]
    code += [f"call ${target.function.name}({','.join(args)}, ...)"]
    return SsaResult(data, code)
    result.code.append(f"call ${target.function.name}({','.join(args)}, ...)")
    return result


@compile_to_ssa.register


@@ 107,7 110,66 @@ def _(target: ConstantExpression, context: CompileContext) -> SsaResult:
def _(target: ReturnStatement, context: CompileContext) -> SsaResult:
    if target.body is None:
        return SsaResult([], ['ret'])
    ret_val_dest = context.next_temp
    result = compile_to_ssa(target.body, context)
    ret_val_dest = context.next_temp - 1
    result.code.append(f"ret %t{ret_val_dest}")
    return result


@compile_to_ssa.register
def _(target: IfStatement, context: CompileContext) -> SsaResult:
    result = compile_to_ssa(target.condition, context)
    condition_dest = context.next_temp - 1
    true_label = context.next_label
    context.next_label += 1
    false_label = context.next_label
    context.next_label += 1
    after_label = context.next_label
    context.next_label += 1
    result.code.append(f"jnz %t{condition_dest}, @l{true_label}, @l{false_label}")
    result.code.append(f"@l{true_label}")
    for statement in target.then:
        result += compile_to_ssa(statement, context)
    result.code.append(f"jmp @l{after_label}")
    result.code.append(f"@l{false_label}")
    for statement in target.els:
        result += compile_to_ssa(statement, context)
    result.code.append(f"jmp @l{after_label}")
    result.code.append(f"@l{after_label}")
    return result


@compile_to_ssa.register
def _(target: ComparisonExpression, context: CompileContext) -> SsaResult:
    result = compile_to_ssa(target.value1, context)
    value1_dest = context.next_temp - 1
    result += compile_to_ssa(target.value2, context)
    value2_dest = context.next_temp - 1
    result_dest = context.next_temp
    context.next_temp += 1
    if target.op == '==':
        result.code.append(f"%t{result_dest} =w ceq %t{value1_dest}, %t{value2_dest}")
    else:
        raise NotImplementedError('comparison ' + target.op)
    return result


@compile_to_ssa.register
def _(target: AddExpression, context: CompileContext) -> SsaResult:
    result = compile_to_ssa(target.term1, context)
    value1_dest = context.next_temp - 1
    result += compile_to_ssa(target.term2, context)
    value2_dest = context.next_temp - 1
    result_reg = context.next_temp
    context.next_temp += 1
    # TODO make sure the types are correct
    result.code.append(f"%t{result_reg} =w add %t{value1_dest}, %t{value2_dest}")
    return result


@compile_to_ssa.register
def _(target: VariableExpression, context: CompileContext) -> SsaResult:
    # TODO make sure any of this is reasonable
    result = context.next_temp
    context.next_temp += 1
    return SsaResult([], [f"%t{result} =w copy %{target.name}"])