~boringcactus/crowbar-reference-compiler

c7de22f575607a7966b6b592dbf81bd3f867a2e4 — Melody Horn 1 year, 4 months ago bd933ee
implement a bunch more stuff
M crowbar_reference_compiler/__init__.py => crowbar_reference_compiler/__init__.py +1 -1
@@ 62,7 62,7 @@ def main():
        if args.out is None:
            args.out = args.input.replace('.cro', '.o')
        extra_gcc_flags.append('-c')
    gcc_result = subprocess.run(['gcc', '-x', 'assembler', '-o', args.out, '-'], input=asm, text=True)
    gcc_result = subprocess.run(['gcc', '-x', 'assembler', '-o', args.out, *extra_gcc_flags, '-'], input=asm, text=True)
    sys.exit(gcc_result.returncode)



M crowbar_reference_compiler/ast.py => crowbar_reference_compiler/ast.py +88 -2
@@ 12,23 12,48 @@ from .parser import parse_header

@dataclass
class Type:
    """Base class for type nodes; concrete types override ``size_bytes``."""

    def size_bytes(self, declarations: List['Declaration']) -> int:
        """Return the size of this type in bytes.

        The base implementation has no answer and always raises
        ``NotImplementedError`` naming the concrete class it was called on.
        """
        concrete_class = type(self)
        raise NotImplementedError(f'type.size_bytes() on {concrete_class} not implemented')


@dataclass
class Expression:
    """Base class for expression nodes; concrete expressions override ``type``."""

    def type(self, declarations: List['Declaration']) -> Type:
        """Return the type this expression evaluates to.

        The base implementation has no answer and always raises
        ``NotImplementedError`` naming the concrete class it was called on.
        """
        # Inside a method body, `type` still resolves to the builtin, not this method.
        concrete_class = type(self)
        raise NotImplementedError(f'expression.type() on {concrete_class} not implemented')


@dataclass
class ConstantExpression(Expression):
    """A literal constant, stored as its source text in ``value``."""
    value: str

    def type(self, _: List['Declaration']) -> Type:
        """Infer the literal's type purely from the shape of its text.

        The declarations argument is ignored. Numeric literals get the
        placeholder type names ``'float?'`` / ``'int?'`` because their exact
        width is not inferred yet.
        """
        literal = self.value
        if literal.startswith('"'):
            # Double-quoted text: pointer to const char.
            return PointerType(ConstType(BasicType('char')))
        if literal.startswith("'"):
            # Single-quoted text: a lone char.
            return BasicType('char')
        if literal in ('true', 'false'):
            return BasicType('bool')
        if '.' in literal:
            return BasicType('float?')  # TODO infer size
        return BasicType('int?')  # TODO infer size and signedness


@dataclass
class VariableExpression(Expression):
    """A reference to a declared name (variable or function)."""
    name: str

    def type(self, declarations: List['Declaration']) -> Type:
        """Find the first usable declaration called ``name`` and return its type.

        Variable declarations/definitions yield their declared type; function
        declarations/definitions yield a ``FunctionType`` built from the return
        type and argument types. Raises ``KeyError`` when nothing matches.
        """
        for candidate in declarations:
            if candidate.name != self.name:
                continue
            if isinstance(candidate, (VariableDeclaration, VariableDefinition)):
                return candidate.type
            if isinstance(candidate, (FunctionDeclaration, FunctionDefinition)):
                return_type = candidate.return_type
                argument_types = [arg.type for arg in candidate.args]
                return FunctionType(return_type, argument_types)
            # Same name but another kind of declaration: keep scanning.
        raise KeyError('unknown variable ' + self.name)


@dataclass
class AddExpression(Expression):


@@ 37,6 62,12 @@ class AddExpression(Expression):


@dataclass
class SubtractExpression(Expression):
    """Binary subtraction node: represents ``term1 - term2``."""
    term1: Expression  # left-hand operand (the value subtracted from)
    term2: Expression  # right-hand operand (the value being subtracted)


@dataclass
class MultiplyExpression(Expression):
    """Binary multiplication node: represents ``factor1 * factor2``."""
    factor1: Expression  # left-hand operand
    factor2: Expression  # right-hand operand


@@ 47,6 78,22 @@ class StructPointerElementExpression(Expression):
    base: Expression
    element: str

    def type(self, declarations: List['Declaration']) -> Type:
        """Return the declared type of field ``element`` in the pointed-to struct.

        ``base`` must evaluate to a pointer whose target is a ``BasicType``
        named ``'struct <name>'``; this is enforced with assertions.

        Raises:
            KeyError: the struct is not declared, is opaque (``fields is None``),
                or has no field called ``element``.
        """
        pointer_type = self.base.type(declarations)
        assert isinstance(pointer_type, PointerType)
        pointee = pointer_type.target
        assert isinstance(pointee, BasicType)
        keyword, struct_name = pointee.name.split(' ')
        assert keyword == 'struct'

        # Only the first declaration with a matching name is consulted.
        struct_decl = next(
            (decl for decl in declarations
             if isinstance(decl, StructDeclaration) and decl.name == struct_name),
            None,
        )
        if struct_decl is None:
            raise KeyError('struct ' + struct_name + ' not found')
        if struct_decl.fields is None:
            raise KeyError('struct ' + struct_name + ' is opaque')
        for field in struct_decl.fields:
            if field.name == self.element:
                return field.type
        raise KeyError('element ' + self.element + ' not found in struct ' + struct_name)


@dataclass
class ArrayIndexExpression(Expression):


@@ 91,6 138,20 @@ class ComparisonExpression(Expression):
class BasicType(Type):
    name: str

    def size_bytes(self, declarations: List['Declaration']) -> int:
        """Return the size in bytes of this named type.

        Only ``'uint8'``, ``'uintsize'`` and ``'struct <name>'`` are handled.
        A struct's size is the plain sum of its field sizes (no padding is
        added here).

        Raises:
            KeyError: the first matching struct declaration is opaque.
            NotImplementedError: any other name, or a struct with no
                matching declaration.
        """
        known_sizes = {'uint8': 1, 'uintsize': 8}
        if self.name in known_sizes:
            return known_sizes[self.name]
        if self.name.startswith('struct'):
            _, struct_name = self.name.split(' ')
            struct_decl = next(
                (decl for decl in declarations
                 if isinstance(decl, StructDeclaration) and decl.name == struct_name),
                None,
            )
            if struct_decl is not None:
                if struct_decl.fields is None:
                    raise KeyError('struct ' + struct_name + ' is opaque')
                total = 0
                for field in struct_decl.fields:
                    total += field.type.size_bytes(declarations)
                return total
        raise NotImplementedError('size of ' + str(self) + ' not yet found')


@dataclass
class ConstType(Type):


@@ 101,6 162,9 @@ class ConstType(Type):
class PointerType(Type):
    target: Type

    def size_bytes(self, declarations: List['Declaration']) -> int:
        """Return the size of a pointer in bytes.

        Currently a fixed 8 regardless of the target type; ``declarations``
        is not consulted.
        """
        return 8 # TODO figure out 32 bit vs 64 bit


@dataclass
class ArrayType(Type):


@@ 226,6 290,14 @@ class UpdateAssignment(AssignmentStatement):
    operation: str
    value: Expression

    def deconstruct(self) -> DirectAssignment:
        """Rewrite this compound assignment as an equivalent direct assignment.

        ``dest OP= value`` becomes ``dest = dest OP value`` for the operators
        ``+=``, ``-=`` and ``*=``.

        Returns:
            A ``DirectAssignment`` to the same destination whose value is the
            matching binary expression over ``destination`` and ``value``.

        Raises:
            NotImplementedError: ``operation`` is any other operator.
        """
        # Each supported operator maps to the binary expression node that
        # performs it; every node takes (left operand, right operand).
        expression_for_operation = {
            '+=': AddExpression,
            '-=': SubtractExpression,
            '*=': MultiplyExpression,
        }
        build_expression = expression_for_operation.get(self.operation)
        if build_expression is None:
            raise NotImplementedError('UpdateAssignment deconstruct with ' + self.operation)
        # The destination node is reused as the left operand, so it appears
        # twice in the rewritten statement.
        return DirectAssignment(self.destination, build_expression(self.destination, self.value))


@dataclass
class CrementAssignment(AssignmentStatement):


@@ 273,12 345,24 @@ class HeaderFile:
    includes: List['HeaderFile']
    contents: List[HeaderFileElement]

    def get_declarations(self) -> List[Declaration]:
        """Return every declaration visible through this header, as one flat list.

        Declarations from each included header come first (in include order,
        gathered recursively), followed by the ``Declaration`` instances found
        in this header's own ``contents``.
        """
        gathered = []
        for header in self.includes:
            gathered.extend(header.get_declarations())
        gathered.extend(item for item in self.contents if isinstance(item, Declaration))
        return gathered


@dataclass
class ImplementationFile:
    """An implementation file: the headers it includes plus its own elements."""
    includes: List[HeaderFile]
    contents: List[ImplementationFileElement]

    def get_declarations(self) -> List[Declaration]:
        """Return every declaration visible in this file, as one flat list.

        Declarations from each included header come first (in include order),
        followed by the ``Declaration`` instances found in ``contents``.
        """
        gathered = []
        for header in self.includes:
            gathered += header.get_declarations()
        gathered += [item for item in self.contents if isinstance(item, Declaration)]
        return gathered


# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
class ASTBuilder(NodeVisitor):


@@ 605,6 689,8 @@ class ASTBuilder(NodeVisitor):
            for op, term in suffix:
                if op.type == '+':
                    base = AddExpression(base, term)
                elif op.type == '-':
                    base = SubtractExpression(base, term)
                else:
                    raise NotImplementedError('arithmetic suffix ' + op)
        return base

M crowbar_reference_compiler/ssagen.py => crowbar_reference_compiler/ssagen.py +206 -15
@@ 1,9 1,13 @@
import dataclasses
from dataclasses import dataclass
from functools import singledispatch
from typing import List

from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \
    VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, AddExpression
    VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, \
    AddExpression, StructPointerElementExpression, Declaration, PointerType, StructDeclaration, VariableDefinition, \
    MultiplyExpression, LogicalNotExpression, DirectAssignment, UpdateAssignment, SizeofExpression, Expression, \
    ConstType, ArrayIndexExpression, ArrayType, NegativeExpression, SubtractExpression, AddressOfExpression


@dataclass


@@ 25,13 29,14 @@ class SsaResult:

@dataclass
class CompileContext:
    declarations: List[Declaration]
    next_data: int = 0
    next_temp: int = 0
    next_label: int = 0


def build_ssa(file: ImplementationFile) -> str:
    """Compile an implementation file and return the generated text.

    The output is all data definitions joined by newlines, a blank line,
    then all code lines joined by newlines.
    """
    # The context has to be seeded with the file's declarations:
    # CompileContext has no default for that field, so a bare
    # CompileContext() would raise TypeError.
    result = compile_to_ssa(file, CompileContext(file.get_declarations()))
    data = '\n'.join(result.data)
    code = '\n'.join(result.code)
    return data + '\n\n' + code


@@ 53,12 58,20 @@ def _(target: ImplementationFile, context: CompileContext):
@compile_to_ssa.register
def _(target: FunctionDefinition, context: CompileContext) -> SsaResult:
    """Compile a function definition into an exported function block.

    The body is compiled statement by statement. The function's arguments,
    and then each declaration statement once it has been compiled, are put
    in front of the visible declarations for the statements that follow.
    A trailing ``ret`` is added when the body does not already end in one,
    including when the body produced no code at all.

    Returns:
        An ``SsaResult`` carrying the body's data definitions and the
        complete function text (header, ``@start`` label, indented body,
        closing brace).
    """
    result = SsaResult([], [])
    caller_context = context
    # Work on copies so function-local declarations never leak to the caller.
    context = dataclasses.replace(context, declarations=target.args + context.declarations)
    for statement in target.body:
        result += compile_to_ssa(statement, context)
        if isinstance(statement, Declaration):
            context = dataclasses.replace(context, declarations=[statement] + context.declarations)
    # dataclasses.replace() copied the counters too, so hand the advanced
    # values back; otherwise the next function compiled with the caller's
    # context would reuse the same numbers (e.g. a second `$data0`).
    caller_context.next_data = context.next_data
    caller_context.next_temp = context.next_temp
    caller_context.next_label = context.next_label
    # Guard the empty-body case before peeking at the last instruction.
    if not result.code or not result.code[-1].startswith('ret'):
        result.code.append('ret')
    code = ['    ' + instr for instr in result.code]
    # TODO types
    args = ','.join(f"l %{x.name}" for x in target.args)
    ret_type = ''
    if target.return_type != BasicType('void'):
        ret_type = 'l'
    code = [f"export function {ret_type} ${target.name}({args}) {{", "@start", *code, "}"]
    return SsaResult(result.data, code)




@@ 82,14 95,28 @@ def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult:

@compile_to_ssa.register
def _(target: ConstantExpression, context: CompileContext) -> SsaResult:
    if target.value.startswith('"'):
    if target.type(context.declarations) == PointerType(ConstType(BasicType('char'))):
        data_dest = context.next_data
        context.next_data += 1
        data = [f"data $data{data_dest} = {{ b {target.value}, b 0 }}"]
        temp = context.next_temp
        context.next_temp += 1
        code = [f"%t{temp} =l copy $data{data_dest}"]
    else:
    elif target.type(context.declarations) == BasicType('char'):
        data = []
        temp = context.next_temp
        context.next_temp += 1
        code = [f"%t{temp} =l copy {ord(target.value[1])}"] # TODO handle escape sequences
    elif target.type(context.declarations) == BasicType('bool'):
        data = []
        temp = context.next_temp
        context.next_temp += 1
        if target.value == 'true':
            value = 1
        else:
            value = 0
        code = [f"%t{temp} =l copy {value}"]
    elif target.type(context.declarations) == BasicType('int?'):
        assert not target.value.startswith('0b')
        assert not target.value.startswith('0B')
        assert not target.value.startswith('0o')


@@ 102,7 129,9 @@ def _(target: ConstantExpression, context: CompileContext) -> SsaResult:
        data = []
        temp = context.next_temp
        context.next_temp += 1
        code = [f"%t{temp} =w copy {target.value}"]
        code = [f"%t{temp} =l copy {target.value}"]
    else:
        raise NotImplementedError('compiling ' + str(target))
    return SsaResult(data, code)




@@ 130,11 159,14 @@ def _(target: IfStatement, context: CompileContext) -> SsaResult:
    result.code.append(f"@l{true_label}")
    for statement in target.then:
        result += compile_to_ssa(statement, context)
    result.code.append(f"jmp @l{after_label}")
    if not result.code[-1].startswith('ret'):
        result.code.append(f"jmp @l{after_label}")
    result.code.append(f"@l{false_label}")
    for statement in target.els:
        result += compile_to_ssa(statement, context)
    result.code.append(f"jmp @l{after_label}")
    if target.els is not None:
        for statement in target.els:
            result += compile_to_ssa(statement, context)
    if not result.code[-1].startswith('ret'):
        result.code.append(f"jmp @l{after_label}")
    result.code.append(f"@l{after_label}")
    return result



@@ 147,10 179,16 @@ def _(target: ComparisonExpression, context: CompileContext) -> SsaResult:
    value2_dest = context.next_temp - 1
    result_dest = context.next_temp
    context.next_temp += 1
    # TODO types, and signedness
    if target.op == '==':
        result.code.append(f"%t{result_dest} =w ceq %t{value1_dest}, %t{value2_dest}")
        op = "ceqw"
    elif target.op == '>=':
        op = "cugew"
    elif target.op == '<=':
        op = "culew"
    else:
        raise NotImplementedError('comparison ' + target.op)
    result.code.append(f"%t{result_dest} =l {op} %t{value1_dest}, %t{value2_dest}")
    return result




@@ 163,7 201,33 @@ def _(target: AddExpression, context: CompileContext) -> SsaResult:
    result_reg = context.next_temp
    context.next_temp += 1
    # TODO make sure the types are correct
    result.code.append(f"%t{result_reg} =w add %t{value1_dest}, %t{value2_dest}")
    result.code.append(f"%t{result_reg} =l add %t{value1_dest}, %t{value2_dest}")
    return result


@compile_to_ssa.register
def _(target: SubtractExpression, context: CompileContext) -> SsaResult:
    """Compile both operands, then emit a ``sub`` into a fresh temporary."""
    ssa = compile_to_ssa(target.term1, context)
    # Each operand's value is taken from the most recently allocated temporary.
    left_temp = context.next_temp - 1
    ssa += compile_to_ssa(target.term2, context)
    right_temp = context.next_temp - 1
    difference_temp = context.next_temp
    context.next_temp = difference_temp + 1
    # TODO make sure the types are correct
    instruction = f"%t{difference_temp} =l sub %t{left_temp}, %t{right_temp}"
    ssa.code.append(instruction)
    return ssa


@compile_to_ssa.register
def _(target: MultiplyExpression, context: CompileContext) -> SsaResult:
    """Compile both factors, then emit a ``mul`` into a fresh temporary."""
    ssa = compile_to_ssa(target.factor1, context)
    # Each factor's value is taken from the most recently allocated temporary.
    left_temp = context.next_temp - 1
    ssa += compile_to_ssa(target.factor2, context)
    right_temp = context.next_temp - 1
    product_temp = context.next_temp
    context.next_temp = product_temp + 1
    # TODO make sure the types are correct
    instruction = f"%t{product_temp} =l mul %t{left_temp}, %t{right_temp}"
    ssa.code.append(instruction)
    return ssa




@@ 172,4 236,131 @@ def _(target: VariableExpression, context: CompileContext) -> SsaResult:
    # TODO make sure any of this is reasonable
    result = context.next_temp
    context.next_temp += 1
    return SsaResult([], [f"%t{result} =w copy %{target.name}"])
    return SsaResult([], [f"%t{result} =l copy %{target.name}"])


@compile_to_ssa.register
def _(target: VariableDefinition, context: CompileContext) -> SsaResult:
    """Compile the initial value, then copy it into ``%<name>``."""
    # TODO figure some shit out
    ssa = compile_to_ssa(target.value, context)
    # The initializer's value is taken from the most recently allocated temporary.
    value_temp = context.next_temp - 1
    copy_instruction = f"%{target.name} =l copy %t{value_temp}"
    ssa.code.append(copy_instruction)
    return ssa


@compile_to_ssa.register
def _(target: LogicalNotExpression, context: CompileContext) -> SsaResult:
    """Compile the operand, then emit a compare-with-zero into a fresh temporary."""
    ssa = compile_to_ssa(target.body, context)
    operand_temp = context.next_temp - 1
    negated_temp = context.next_temp
    context.next_temp = negated_temp + 1
    # The result is the operand compared for equality against 0.
    instruction = f"%t{negated_temp} =l ceqw %t{operand_temp}, 0"
    ssa.code.append(instruction)
    return ssa


@compile_to_ssa.register
def _(target: NegativeExpression, context: CompileContext) -> SsaResult:
    """Compile unary minus by rewriting ``-x`` as ``0 - x``."""
    zero = ConstantExpression('0')
    rewritten = SubtractExpression(zero, target.body)
    return compile_to_ssa(rewritten, context)


@compile_to_ssa.register
def _(target: ArrayIndexExpression, context: CompileContext) -> SsaResult:
    """Compile ``array[index]`` as scale, add, load.

    The array and index are compiled first. The index is then multiplied by
    the element size, added to the base, and the result loaded. The last
    two instructions are always the address ``add`` followed by the load.
    """
    ssa = compile_to_ssa(target.array, context)
    base_temp = context.next_temp - 1
    ssa += compile_to_ssa(target.index, context)
    index_temp = context.next_temp - 1

    # Look through one pointer level to reach the array type itself.
    container_type = target.array.type(context.declarations)
    if isinstance(container_type, PointerType):
        container_type = container_type.target
    assert isinstance(container_type, ArrayType)
    element_size = container_type.contents.size_bytes(context.declarations)

    # Reserve three consecutive temporaries: offset, address, loaded value.
    offset_temp = context.next_temp
    address_temp = offset_temp + 1
    value_temp = offset_temp + 2
    context.next_temp = offset_temp + 3

    # TODO types
    ssa.code.extend([
        f"%t{offset_temp} =l mul %t{index_temp}, {element_size}",
        f"%t{address_temp} =l add %t{base_temp}, %t{offset_temp}",
        f"%t{value_temp} =l loadsw %t{address_temp}",
    ])
    return ssa


@compile_to_ssa.register
def _(target: StructPointerElementExpression, context: CompileContext) -> SsaResult:
    """Compile a field access through a struct pointer as add, load.

    The base pointer is compiled first. The field's byte offset is the sum
    of the sizes of the fields declared before it; that offset is added to
    the base and the result loaded. The last two instructions are always
    the address ``add`` followed by the load.

    Raises:
        KeyError: the struct is not declared, is opaque (``fields is None``),
            or has no field named ``target.element``.
    """
    result = compile_to_ssa(target.base, context)
    base_dest = context.next_temp - 1
    # hoooo boy.
    base_type = target.base.type(context.declarations)
    assert isinstance(base_type, PointerType)
    assert isinstance(base_type.target, BasicType)
    hopefully_struct, struct_name = base_type.target.name.split(' ')
    assert hopefully_struct == 'struct'
    target_struct = None
    for decl in context.declarations:
        if isinstance(decl, StructDeclaration) and decl.name == struct_name:
            if decl.fields is None:
                raise KeyError('struct ' + struct_name + ' is opaque')
            target_struct = decl
            break
    if target_struct is None:
        raise KeyError('struct ' + struct_name + ' not found')
    offset = 0
    for field in target_struct.fields:
        if field.name == target.element:
            break
        offset += field.type.size_bytes(context.declarations)
    else:
        # No field matched: without this check the offset would equal the
        # whole struct size and a load past its end would be emitted silently.
        raise KeyError('element ' + target.element + ' not found in struct ' + struct_name)
    temp = context.next_temp
    context.next_temp += 1
    result_dest = context.next_temp
    context.next_temp += 1
    # TODO types
    result.code.append(f"%t{temp} =l add %t{base_dest}, {offset}")
    result.code.append(f"%t{result_dest} =l loadsw %t{temp}")
    return result


@compile_to_ssa.register
def _(target: AddressOfExpression, context: CompileContext) -> SsaResult:
    """Compile ``&expr`` for struct-pointer field and array-index operands.

    The operand is compiled as a normal read and its final instruction is
    then discarded, leaving the computed address as the newest temporary.

    Raises:
        NotImplementedError: the operand is any other kind of expression.
    """
    operand = target.body
    if not isinstance(operand, (StructPointerElementExpression, ArrayIndexExpression)):
        raise NotImplementedError('address of ' + str(type(target.body)))
    ssa = compile_to_ssa(operand, context)
    # Drop the trailing load and give back the temporary it had claimed.
    ssa.code.pop()
    context.next_temp -= 1
    return ssa


@compile_to_ssa.register
def _(target: DirectAssignment, context: CompileContext) -> SsaResult:
    """Compile an assignment to a struct-pointer field or an array element.

    The value is compiled first. The destination is then compiled as if it
    were being read, and its final load instruction is swapped for a store
    of the value to the same address.

    Raises:
        NotImplementedError: the destination is a plain variable or any
            other unsupported expression kind.
    """
    ssa = compile_to_ssa(target.value, context)
    value_temp = context.next_temp - 1
    destination = target.destination
    if isinstance(destination, VariableExpression):
        raise NotImplementedError('assign directly to variable')
    if not isinstance(destination, (StructPointerElementExpression, ArrayIndexExpression)):
        raise NotImplementedError('assign to ' + str(type(target.destination)))

    destination_ssa = compile_to_ssa(destination, context)
    trailing_load = destination_ssa.code.pop()
    # The load has exactly four space-separated parts; the last is the address operand.
    _dest, _assign, _opcode, location = trailing_load.split(' ')
    # TODO type
    destination_ssa.code.append(f"storew %t{value_temp}, {location}")
    ssa += destination_ssa
    return ssa


@compile_to_ssa.register
def _(target: UpdateAssignment, context: CompileContext) -> SsaResult:
    """Compile a compound assignment by way of its direct-assignment form."""
    direct_form = target.deconstruct()
    return compile_to_ssa(direct_form, context)


@compile_to_ssa.register
def _(target: SizeofExpression, context: CompileContext) -> SsaResult:
    """Compile ``sizeof`` into an integer constant.

    When the operand is an expression, its type is measured; otherwise the
    operand is measured directly. The byte count is then compiled as a
    ``ConstantExpression``.
    """
    operand = target.body
    if isinstance(operand, Expression):
        measured = operand.type(context.declarations)
    else:
        measured = operand
    byte_count = measured.size_bytes(context.declarations)
    size_literal = ConstantExpression(str(byte_count))
    return compile_to_ssa(size_literal, context)