~boringcactus/crowbar-reference-compiler

68c303c285032e2f0398352b70eb79736a72296f — Melody Horn 1 year, 9 months ago ab58859
add ability to extract declarations from file
M crowbar_reference_compiler/__init__.py => crowbar_reference_compiler/__init__.py +4 -0
@@ 1,3 1,4 @@
from .declarations import load_declarations
from .parser import parse_header, parse_implementation
from .scanner import scan
from .ssagen import compile_to_ssa


@@ 31,6 32,9 @@ def main():
        with open(args.out, 'w', encoding='utf-8') as output_file:
            output_file.write(str(parse_tree))
        return

    decls = load_declarations(parse_tree)

    ssa = compile_to_ssa(parse_tree)
    if args.stop_at_qbe_ssa:
        if args.out is None:

A crowbar_reference_compiler/declarations.py => crowbar_reference_compiler/declarations.py +127 -0
@@ 0,0 1,127 @@
from pathlib import Path

from parsimonious import NodeVisitor  # type: ignore

from .scanner import scan
from .parser import parse_header


class DeclarationVisitor(NodeVisitor):
    def __init__(self, include_folders):
        self.data = []
        self.include_folders = include_folders

    def visit_HeaderFile(self, node, visited_children):
        includes, elements = visited_children
        return elements

    def visit_ImplementationFile(self, node, visited_children):
        return [x for x in visited_children if x is not None]

    def visit_IncludeStatement(self, node, visited_children):
        include, included_header, semicolon = visited_children
        assert include.text[0].type == 'include'
        assert included_header.type == 'string_literal'
        included_header = included_header.data
        assert semicolon.text[0].type == ';'
        for include_folder in self.include_folders:
            header = Path(include_folder) / included_header
            if header.exists():
                with open(header, 'r', encoding='utf-8') as header_file:
                    header_text = header_file.read()
                header_parse_tree = parse_header(scan(header_text))
                return self.visit(header_parse_tree)
        raise FileNotFoundError(included_header)

    def visit_NormalStructDefinition(self, node, visited_children):
        struct, name, lbrace, fields, rbrace = visited_children
        assert struct.text[0].type == 'struct'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        return f"struct {name}"

    def visit_OpaqueStructDefinition(self, node, visited_children):
        opaque, struct, name, semi = visited_children
        assert opaque.text[0].type == 'opaque'
        assert struct.text[0].type == 'struct'
        assert semi.text[0].type == ';'
        name = name.data
        return f"struct {name}"

    def visit_EnumDefinition(self, node, visited_children):
        enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children
        assert enum.text[0].type == 'enum'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        return f"enum {name}"

    def visit_RobustUnionDefinition(self, node, visited_children):
        union, name, lbrace, tag, body, rbrace = visited_children
        assert union.text[0].type == 'union'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        return f"union {name}"

    def visit_FragileUnionDefinition(self, node, visited_children):
        fragile, union, name, lbrace, body, rbrace = visited_children
        assert fragile.text[0].type == 'fragile'
        assert union.text[0].type == 'union'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        return f"union {name}"

    def visit_FunctionDeclaration(self, node, visited_children):
        signature, semi = visited_children
        assert semi.text[0].type == ';'
        return signature

    def visit_VariableDefinition(self, node, visited_children):
        type, name, eq, value, semi = visited_children
        assert eq.text[0].type == '='
        assert semi.text[0].type == ';'
        name = name.data
        return name

    def visit_VariableDeclaration(self, node, visited_children):
        type, name, semi = visited_children
        assert semi.text[0].type == ';'
        name = name.data
        return name

    def visit_FunctionDefinition(self, node, visited_children):
        signature, body = visited_children
        return signature

    def visit_FunctionSignature(self, node, visited_children):
        return_type, name, lparen, args, rparen = visited_children
        assert name.type == 'identifier'
        name = name.data
        assert lparen.text[0].type == '('
        assert rparen.text[0].type == ')'
        return return_type, name, args

    def visit_constant(self, node, visited_children):
        return node.text[0]

    def visit_string_literal(self, node, visited_children):
        return node.text[0]

    def visit_identifier(self, node, visited_children):
        return node.text[0]

    def generic_visit(self, node, visited_children):
        """ The generic visit method. """
        if not visited_children:
            return node
        if len(visited_children) == 1:
            return visited_children[0]
        return visited_children


def load_declarations(parse_tree):
    declarations = DeclarationVisitor([])
    return declarations.visit(parse_tree)

M crowbar_reference_compiler/scanner.py => crowbar_reference_compiler/scanner.py +1 -1
@@ 49,7 49,7 @@ HEX_FLOAT_CONSTANT = re.compile(r"0(fx|FX)[0-9a-fA-F_]+\.[0-9a-fA-F_]+[pP][+-]?[
_ESCAPE_SEQUENCE = r"""\\['"\\rnt0]|\\x[0-9a-fA-F]{2}|\\u[0-9a-fA-F]{4}|\\U[0-9a-fA-F]{8}"""
CHAR_CONSTANT = re.compile(r"'([^'\\]|" + _ESCAPE_SEQUENCE + r")'")
STRING_LITERAL = re.compile(r'"([^"\\]|' + _ESCAPE_SEQUENCE + r')+"')
PUNCTUATOR = re.compile(r"->|\+\+|--|>>|<<|<=|>=|&&|\|\||[=!+\-*/%&|^]=|[\[\](){}.,+\-*/%;!&|^~><=]")
PUNCTUATOR = re.compile(r"->|\+\+|--|>>|<<|<=|>=|&&|\|\||[=!+\-*/%&|^]=|[\[\](){}.,+\-*/%;:!&|^~><=]")
WHITESPACE = re.compile(r"[\p{Z}\p{Cc}]+")
COMMENT = re.compile(r"(//[^\n]*\n)|(/\*.*?\*/)", re.DOTALL)


A tests/test_declarations.py => tests/test_declarations.py +38 -0
@@ 0,0 1,38 @@
import unittest

from crowbar_reference_compiler import compile_to_ssa, load_declarations, parse_header, parse_implementation, scan


class TestDeclarationLoading(unittest.TestCase):
    def test_kitchen_sink(self):
        code = r"""
struct normal {
    bool fake;
}

opaque struct ope;

enum sample {
    Testing,
}

union robust {
    enum sample tag;
    
    switch (tag) {
        case Testing: bool testPassed;
    }
}

fragile union not_robust {
    int8 sample;
    bool nope;
}
"""
        tokens = scan(code)
        parse_tree = parse_header(tokens)
        decls = load_declarations(parse_tree)
        self.assertListEqual(decls, ['struct normal', 'struct ope', 'enum sample', 'union robust', 'union not_robust'])

if __name__ == '__main__':
    unittest.main()