~boringcactus/crowbar-reference-compiler

eaf789901101b8958c2555a0af300d7471707fea — Melody Horn 1 year, 9 months ago 68c303c
preserve more info about declarations
M crowbar_reference_compiler/__init__.py => crowbar_reference_compiler/__init__.py +4 -3
@@ 16,8 16,8 @@ def main():
    args.add_argument('--stop-at-qbe-ssa', action='store_true')
    args.add_argument('-S', '--stop-at-assembly', action='store_true')
    args.add_argument('-c', '--stop-at-object', action='store_true')
    args.add_argument('-D', '--define-constant', help='define a constant with some literal value')
    args.add_argument('-I', '--include-dir', help='folder to look for included headers within')
    args.add_argument('-D', '--define-constant', action='append', help='define a constant with some literal value')
    args.add_argument('-I', '--include-dir', action='append', help='folder to look for included headers within')
    args.add_argument('-o', '--out', help='output file')
    args.add_argument('input', help='input file')



@@ 33,7 33,8 @@ def main():
            output_file.write(str(parse_tree))
        return

    decls = load_declarations(parse_tree)
    decls = load_declarations(parse_tree, args.include_dir)
    print(decls)

    ssa = compile_to_ssa(parse_tree)
    if args.stop_at_qbe_ssa:

M crowbar_reference_compiler/declarations.py => crowbar_reference_compiler/declarations.py +159 -12
@@ 1,4 1,6 @@
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union

from parsimonious import NodeVisitor  # type: ignore



@@ 6,6 8,71 @@ from .scanner import scan
from .parser import parse_header


@dataclass
class Type:
    pass


@dataclass
class BasicType(Type):
    name: str


@dataclass
class PointerType(Type):
    target: Type


@dataclass
class ArrayType(Type):
    contents: Type
    size: int


@dataclass
class Declaration:
    name: str


@dataclass
class VariableDeclaration(Declaration):
    """Represents the declaration of a variable."""
    type: Type
    value: Optional[str]


@dataclass
class Declarations:
    included: List[Declaration]


@dataclass
class StructDeclaration(Declaration):
    """Represents the declaration of a struct type."""
    fields: Optional[List[VariableDeclaration]]


@dataclass
class EnumDeclaration(Declaration):
    """Represents the declaration of an enum type."""
    values: List[Tuple[str, Optional[int]]]


@dataclass
class UnionDeclaration(Declaration):
    """Represents the declaration of a union type."""
    tag: Optional[VariableDeclaration]
    cases: Union[List[VariableDeclaration], List[Tuple[str, Optional[VariableDeclaration]]]]


@dataclass
class FunctionDeclaration(Declaration):
    """Represents the declaration of a function."""
    return_type: Type
    args: List[VariableDeclaration]


# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
class DeclarationVisitor(NodeVisitor):
    def __init__(self, include_folders):
        self.data = []


@@ 16,13 83,15 @@ class DeclarationVisitor(NodeVisitor):
        return elements

    def visit_ImplementationFile(self, node, visited_children):
        return [x for x in visited_children if x is not None]
        includes, elements = visited_children
        includes = [y for x in includes for y in x]
        return [x for x in includes + elements if x is not None]

    def visit_IncludeStatement(self, node, visited_children):
        include, included_header, semicolon = visited_children
        assert include.text[0].type == 'include'
        assert included_header.type == 'string_literal'
        included_header = included_header.data
        included_header = included_header.data.strip('"')
        assert semicolon.text[0].type == ';'
        for include_folder in self.include_folders:
            header = Path(include_folder) / included_header


@@ 39,7 108,9 @@ class DeclarationVisitor(NodeVisitor):
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        return f"struct {name}"
        if not isinstance(fields, list):
            fields = [fields]
        return StructDeclaration(name, fields)

    def visit_OpaqueStructDefinition(self, node, visited_children):
        opaque, struct, name, semi = visited_children


@@ 47,7 118,7 @@ class DeclarationVisitor(NodeVisitor):
        assert struct.text[0].type == 'struct'
        assert semi.text[0].type == ';'
        name = name.data
        return f"struct {name}"
        return StructDeclaration(name, None)

    def visit_EnumDefinition(self, node, visited_children):
        enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children


@@ 55,7 126,18 @@ class DeclarationVisitor(NodeVisitor):
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        return f"enum {name}"
        values = [first_member]
        for _, v in extra_members:
            values.append(v)
        return EnumDeclaration(name, values)

    def visit_EnumMember(self, node, visited_children):
        name, equals_value = visited_children
        name = name.data
        if not isinstance(equals_value, list):
            return name, None
        _, value = equals_value
        return name, value

    def visit_RobustUnionDefinition(self, node, visited_children):
        union, name, lbrace, tag, body, rbrace = visited_children


@@ 63,7 145,40 @@ class DeclarationVisitor(NodeVisitor):
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        return f"union {name}"
        expected_tagname, body = body
        if tag.name != expected_tagname:
            raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")
        if not isinstance(body, list):
            body = [body]
        return UnionDeclaration(name, tag, body)

    def visit_UnionBody(self, node, visited_children):
        switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
        assert switch.text[0].type == 'switch'
        assert lparen.text[0].type == '('
        assert rparen.text[0].type == ')'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        return tag.data, body

    def visit_UnionBodySet(self, node, visited_children):
        cases, var = visited_children
        if isinstance(cases, list):
            cases = cases[0]
        if isinstance(var, VariableDeclaration):
            return cases, var
        else:
            return cases, None

    def visit_CaseSpecifier(self, node, visited_children):
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        # TODO don't explode on 'default:'
        case, expr, colon = visited_children
        while isinstance(expr, list):
            expr = expr[0]
        # TODO don't explode on nontrivial expression
        return expr.data

    def visit_FragileUnionDefinition(self, node, visited_children):
        fragile, union, name, lbrace, body, rbrace = visited_children


@@ 72,7 187,7 @@ class DeclarationVisitor(NodeVisitor):
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        return f"union {name}"
        return UnionDeclaration(name, None, body)

    def visit_FunctionDeclaration(self, node, visited_children):
        signature, semi = visited_children


@@ 84,13 199,13 @@ class DeclarationVisitor(NodeVisitor):
        assert eq.text[0].type == '='
        assert semi.text[0].type == ';'
        name = name.data
        return name
        return VariableDeclaration(name, type, value)

    def visit_VariableDeclaration(self, node, visited_children):
        type, name, semi = visited_children
        assert semi.text[0].type == ';'
        name = name.data
        return name
        return VariableDeclaration(name, type, None)

    def visit_FunctionDefinition(self, node, visited_children):
        signature, body = visited_children


@@ 102,7 217,39 @@ class DeclarationVisitor(NodeVisitor):
        name = name.data
        assert lparen.text[0].type == '('
        assert rparen.text[0].type == ')'
        return return_type, name, args
        return FunctionDeclaration(name, return_type, args)

    def visit_BasicType(self, node, visited_children):
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        if isinstance(visited_children, list):
            if len(visited_children) == 3:
                # parenthesized!
                lparen, ty, rparen = visited_children
                assert lparen.text[0].type == '('
                assert rparen.text[0].type == ')'
                return ty
            else:
                category, name = visited_children
                category = category.text[0].type
                name = name.data
                return BasicType(f"{category} {name}")
        return BasicType(visited_children.text[0].type)

    def visit_ArrayType(self, node, visited_children):
        contents, lbracket, size, rbracket = visited_children
        assert lbracket.text[0].type == '['
        assert rbracket.text[0].type == ']'
        # TODO don't explode on nontrivial expression
        while isinstance(size, list):
            size = size[0]
        size = int(size.data)
        return ArrayType(contents, size)

    def visit_PointerType(self, node, visited_children):
        contents, splat = visited_children
        assert splat.text[0].type == '*'
        return PointerType(contents)

    def visit_constant(self, node, visited_children):
        return node.text[0]


@@ 122,6 269,6 @@ class DeclarationVisitor(NodeVisitor):
        return visited_children


def load_declarations(parse_tree):
    declarations = DeclarationVisitor([])
def load_declarations(parse_tree, include_dirs):
    declarations = DeclarationVisitor(include_dirs)
    return declarations.visit(parse_tree)

M crowbar_reference_compiler/scanner.py => crowbar_reference_compiler/scanner.py +1 -1
@@ 78,7 78,7 @@ def scan(code):
            remaining = remaining[id_match.end():]
            continue
        was_constant = False
        for constant in [DECIMAL_CONSTANT, BINARY_CONSTANT, OCTAL_CONSTANT, HEX_CONSTANT, FLOAT_CONSTANT, HEX_FLOAT_CONSTANT, CHAR_CONSTANT]:
        for constant in [HEX_CONSTANT, BINARY_CONSTANT, OCTAL_CONSTANT, HEX_FLOAT_CONSTANT, FLOAT_CONSTANT, DECIMAL_CONSTANT, CHAR_CONSTANT]:
            match = constant.match(remaining)
            if match:
                result.append(Token('constant', match.group()))

M tests/test_declarations.py => tests/test_declarations.py +17 -2
@@ 1,6 1,8 @@
import unittest

from crowbar_reference_compiler import compile_to_ssa, load_declarations, parse_header, parse_implementation, scan
from crowbar_reference_compiler.declarations import ArrayType, BasicType, EnumDeclaration, PointerType, \
    StructDeclaration, UnionDeclaration, VariableDeclaration


class TestDeclarationLoading(unittest.TestCase):


@@ 8,6 10,7 @@ class TestDeclarationLoading(unittest.TestCase):
        code = r"""
struct normal {
    bool fake;
    (uint8[3])* data;
}

opaque struct ope;


@@ 31,8 34,20 @@ fragile union not_robust {
"""
        tokens = scan(code)
        parse_tree = parse_header(tokens)
        decls = load_declarations(parse_tree)
        self.assertListEqual(decls, ['struct normal', 'struct ope', 'enum sample', 'union robust', 'union not_robust'])
        decls = load_declarations(parse_tree, [])
        normal = StructDeclaration('normal', [
            VariableDeclaration('fake', BasicType('bool'), None),
            VariableDeclaration('data', PointerType(ArrayType(BasicType('uint8'), 3)), None),
        ])
        ope = StructDeclaration('ope', None)
        sample = EnumDeclaration('sample', [('Testing', None)])
        robust = UnionDeclaration('robust', VariableDeclaration('tag', BasicType('enum sample'), None),
                                  [('Testing', VariableDeclaration('testPassed', BasicType('bool'), None))])
        not_robust = UnionDeclaration('not_robust', None,
                                      [VariableDeclaration('sample', BasicType('int8'), None),
                                       VariableDeclaration('nope', BasicType('bool'), None)])
        self.assertListEqual(decls, [normal, ope, sample, robust, not_robust])


if __name__ == '__main__':
    unittest.main()