~boringcactus/crowbar-reference-compiler

9dfc552c0703c5e14ea472eb5431719b2e0d6400 — Melody Horn 1 year, 9 months ago 3597995
test AST with the full power of our scdoc translation
3 files changed, 598 insertions(+), 108 deletions(-)

M crowbar_reference_compiler/ast.py
M tests/test_ast.py
M tests/test_parsing.py
M crowbar_reference_compiler/ast.py => crowbar_reference_compiler/ast.py +245 -38
@@ 1,8 1,10 @@
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union
import typing
from typing import ClassVar, List, Tuple, Union

from parsimonious import NodeVisitor  # type: ignore
from parsimonious.expressions import Compound, OneOf, Optional, Sequence, TokenMatcher, ZeroOrMore  # type: ignore

from .scanner import scan
from .parser import parse_header


@@ 29,6 31,63 @@ class VariableExpression(Expression):


@dataclass
class AddExpression(Expression):
    term1: Expression
    term2: Expression


@dataclass
class MultiplyExpression(Expression):
    factor1: Expression
    factor2: Expression


@dataclass
class StructPointerElementExpression(Expression):
    base: Expression
    element: str


@dataclass
class ArrayIndexExpression(Expression):
    array: Expression
    index: Expression


@dataclass
class FunctionCallExpression(Expression):
    function: Expression
    arguments: List[Expression]


@dataclass
class LogicalNotExpression(Expression):
    body: Expression


@dataclass
class NegativeExpression(Expression):
    body: Expression


@dataclass
class AddressOfExpression(Expression):
    body: Expression


@dataclass
class SizeofExpression(Expression):
    body: Union[Type, Expression]


@dataclass
class ComparisonExpression(Expression):
    value1: Expression
    op: str
    value2: Expression


@dataclass
class BasicType(Type):
    name: str



@@ 89,13 148,13 @@ class ExpressionStatement(Statement):
class IfStatement(Statement):
    condition: Expression
    then: List[Statement]
    els: Optional[List[Statement]]
    els: typing.Optional[List[Statement]]


@dataclass
class SwitchStatement(Statement):
    expression: Expression
    body: List[Union[Optional[Expression], Statement]]
    body: List[Union[typing.Optional[Expression], Statement]]


@dataclass


@@ 122,7 181,7 @@ class VariableDeclaration(Declaration, HeaderFileElement):


@dataclass
class VariableDefinition(Declaration, ImplementationFileElement, Statement):
class VariableDefinition(Declaration, HeaderFileElement, ImplementationFileElement, Statement):
    """Represents the definition of a variable."""
    type: Type
    value: Expression


@@ 152,7 211,7 @@ class BreakStatement(Statement):

@dataclass
class ReturnStatement(Statement):
    body: Optional[Expression]
    body: typing.Optional[Expression]


@dataclass


@@ 177,20 236,20 @@ class CrementAssignment(AssignmentStatement):
@dataclass
class StructDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
    """Represents the declaration of a struct type."""
    fields: Optional[List[VariableDeclaration]]
    fields: typing.Optional[List[VariableDeclaration]]


@dataclass
class EnumDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
    """Represents the declaration of an enum type."""
    values: List[Tuple[str, Optional[int]]]
    values: List[Tuple[str, typing.Optional[int]]]


@dataclass
class UnionDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
    """Represents the declaration of a union type."""
    tag: Optional[VariableDeclaration]
    cases: Union[List[VariableDeclaration], List[Tuple[Expression, Optional[VariableDeclaration]]]]
    tag: typing.Optional[VariableDeclaration]
    cases: Union[List[VariableDeclaration], List[Tuple[Expression, typing.Optional[VariableDeclaration]]]]


@dataclass


@@ 210,6 269,7 @@ class FunctionDefinition(Declaration, HeaderFileElement, ImplementationFileEleme

@dataclass
class HeaderFile:
    grammar: ClassVar[str] = "HeaderFile <- IncludeStatement* HeaderFileElement+"
    includes: List['HeaderFile']
    contents: List[HeaderFileElement]



@@ 227,8 287,6 @@ class ASTBuilder(NodeVisitor):

    def visit_HeaderFile(self, node, visited_children) -> HeaderFile:
        includes, elements = visited_children
        if not isinstance(includes, list):
            includes = []
        return HeaderFile(includes, elements)

    def visit_ImplementationFile(self, node, visited_children) -> ImplementationFile:


@@ 253,36 311,38 @@ class ASTBuilder(NodeVisitor):
    def visit_NormalStructDefinition(self, node, visited_children) -> StructDeclaration:
        struct, name, lbrace, fields, rbrace = visited_children
        assert struct.type == 'struct'
        assert name.type == 'identifier'
        name = name.data
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        name = name.data
        if not isinstance(fields, list):
            fields = [fields]
        return StructDeclaration(name, fields)

    def visit_OpaqueStructDefinition(self, node, visited_children) -> StructDeclaration:
        opaque, struct, name, semi = visited_children
        assert opaque.type == 'opaque'
        assert struct.type == 'struct'
        assert semi.type == ';'
        assert name.type == 'identifier'
        name = name.data
        assert semi.type == ';'
        return StructDeclaration(name, None)

    def visit_EnumDefinition(self, node, visited_children) -> EnumDeclaration:
        enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children
        assert enum.type == 'enum'
        assert name.type == 'identifier'
        name = name.data
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        name = name.data
        values = [first_member]
        for _, v in extra_members:
            values.append(v)
        return EnumDeclaration(name, values)

    def visit_EnumMember(self, node, visited_children) -> Tuple[str, Optional[Expression]]:
    def visit_EnumMember(self, node, visited_children) -> Tuple[str, typing.Optional[Expression]]:
        name, equals_value = visited_children
        assert name.type == 'identifier'
        name = name.data
        if len(equals_value) == 0:
        if equals_value is None:
            return name, None
        _, value = equals_value
        return name, value


@@ 290,9 350,10 @@ class ASTBuilder(NodeVisitor):
    def visit_RobustUnionDefinition(self, node, visited_children) -> UnionDeclaration:
        union, name, lbrace, tag, body, rbrace = visited_children
        assert union.type == 'union'
        assert name.type == 'identifier'
        name = name.data
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        name = name.data
        expected_tagname, body = body
        if tag.name != expected_tagname:
            raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")


@@ 300,7 361,7 @@ class ASTBuilder(NodeVisitor):
            body = [body]
        return UnionDeclaration(name, tag, body)

    def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, Optional[VariableDeclaration]]]]:
    def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, typing.Optional[VariableDeclaration]]]]:
        switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
        assert switch.type == 'switch'
        assert lparen.type == '('


@@ 309,7 370,7 @@ class ASTBuilder(NodeVisitor):
        assert rbrace.type == '}'
        return tag.data, body

    def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, Optional[VariableDeclaration]]:
    def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, typing.Optional[VariableDeclaration]]:
        cases, var = visited_children
        if isinstance(cases, list):
            cases = cases[0]


@@ 323,17 384,16 @@ class ASTBuilder(NodeVisitor):
            visited_children = visited_children[0]
        # TODO don't explode on 'default:'
        case, expr, colon = visited_children
        while isinstance(expr, list):
            expr = expr[0]
        return expr

    def visit_FragileUnionDefinition(self, node, visited_children) -> UnionDeclaration:
        fragile, union, name, lbrace, body, rbrace = visited_children
        assert fragile.type == 'fragile'
        assert union.type == 'union'
        assert name.type == 'identifier'
        name = name.data
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        name = name.data
        return UnionDeclaration(name, None, body)

    def visit_FunctionDeclaration(self, node, visited_children) -> FunctionDeclaration:


@@ 343,15 403,17 @@ class ASTBuilder(NodeVisitor):

    def visit_VariableDefinition(self, node, visited_children) -> VariableDefinition:
        type, name, eq, value, semi = visited_children
        assert name.type == 'identifier'
        name = name.data
        assert eq.type == '='
        assert semi.type == ';'
        name = name.data
        return VariableDefinition(name, type, value)

    def visit_VariableDeclaration(self, node, visited_children) -> VariableDeclaration:
        type, name, semi = visited_children
        assert semi.type == ';'
        assert name.type == 'identifier'
        name = name.data
        assert semi.type == ';'
        return VariableDeclaration(name, type)

    def visit_FunctionDefinition(self, node, visited_children) -> FunctionDefinition:


@@ 363,9 425,53 @@ class ASTBuilder(NodeVisitor):
        assert name.type == 'identifier'
        name = name.data
        assert lparen.type == '('
        if args is None:
            args = []
        assert rparen.type == ')'
        return FunctionDeclaration(name, return_type, args)

    def visit_SignatureArguments(self, node, visited_children) -> List[VariableDeclaration]:
        first_type, first_name, rest, comma = visited_children
        result = [VariableDeclaration(first_name.data, first_type)]
        for comma, ty, name in rest:
            result.append(VariableDeclaration(name.data, ty))
        return result

    def visit_IfStatement(self, node, visited_children):
        kwd, lparen, condition, rparen, then, els = visited_children
        assert kwd.type == 'if'
        assert lparen.type == '('
        assert rparen.type == ')'
        if els is not None:
            kwd, els = els
            assert kwd.type == 'else'
        return IfStatement(condition, then, els)

    def visit_ReturnStatement(self, node, visited_children):
        ret, body, semi = visited_children
        assert ret.type == 'return'
        assert semi.type == ';'
        return ReturnStatement(body)

    def visit_DirectAssignmentBody(self, node, visited_children):
        dest, eq, value = visited_children
        assert eq.type == '='
        return DirectAssignment(dest, value)

    def visit_UpdateAssignmentBody(self, node, visited_children):
        dest, op, value = visited_children
        return UpdateAssignment(dest, op.type, value)

    def visit_AssignmentStatement(self, node, visited_children):
        assignment, semi = visited_children
        assert semi.type == ';'
        return assignment

    def visit_ExpressionStatement(self, node, visited_children):
        expression, semi = visited_children
        assert semi.type == ';'
        return ExpressionStatement(expression)

    def visit_BasicType(self, node, visited_children) -> Type:
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]


@@ 379,17 485,23 @@ class ASTBuilder(NodeVisitor):
            else:
                category, name = visited_children
                category = category.type
                assert name.type == 'identifier'
                name = name.data
                return BasicType(f"{category} {name}")
        return BasicType(visited_children.type)

    def visit_ConstType(self, node, visited_children) -> ConstType:
        const, contents = visited_children
        assert const.type == 'const'
        return ConstType(contents)

    def visit_FunctionType(self, node, visited_children):
        raise NotImplementedError('function types')

    def visit_ArrayType(self, node, visited_children) -> ArrayType:
        contents, lbracket, size, rbracket = visited_children
        assert lbracket.type == '['
        assert rbracket.type == ']'
        # TODO don't explode on nontrivial expression
        while isinstance(size, list):
            size = size[0]
        return ArrayType(contents, size)

    def visit_PointerType(self, node, visited_children) -> PointerType:


@@ 397,6 509,12 @@ class ASTBuilder(NodeVisitor):
        assert splat.type == '*'
        return PointerType(contents)

    def visit_Block(self, node, visited_children) -> List[Expression]:
        lbrace, body, rbrace = visited_children
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        return body

    def visit_AtomicExpression(self, node, visited_children) -> Expression:
        if isinstance(visited_children, list) and len(visited_children) == 3:
            lparen, body, rparen = visited_children


@@ 414,17 532,106 @@ class ASTBuilder(NodeVisitor):
            return ConstantExpression(body.type)
        raise NotImplementedError()

    def visit_StructPointerElementSuffix(self, node, visited_children):
        separator, element = visited_children
        assert separator.type == '->'
        return lambda base: StructPointerElementExpression(base, element.data)

    def visit_CommasExpressionList(self, node, visited_children):
        first, rest, comma = visited_children
        result = [first]
        for comma, next in rest:
            result.append(next)
        return result

    def visit_FunctionCallSuffix(self, node, visited_children):
        lparen, args, rparen = visited_children
        assert lparen.type == '('
        assert rparen.type == ')'
        if args is None:
            args = []
        return lambda base: FunctionCallExpression(base, args)

    def visit_ArrayIndexSuffix(self, node, visited_children):
        lbracket, index, rbracket = visited_children
        assert lbracket.type == '['
        assert rbracket.type == ']'
        return lambda base: ArrayIndexExpression(base, index)

    def visit_ObjectExpression(self, node, visited_children) -> Expression:
        if isinstance(visited_children, list):
            base, suffix = visited_children[0]
            if len(suffix) > 0:
                for suffix in suffix:
                    base = suffix(base)
            return base
        raise NotImplementedError('array/struct literals')

    def visit_NegativeExpression(self, node, visited_children):
        minus, body = visited_children
        assert minus.type == '-'
        return NegativeExpression(body)

    def visit_AddressOfExpression(self, node, visited_children):
        ampersand, body = visited_children
        assert ampersand.type == '&'
        return AddressOfExpression(body)

    def visit_LogicalNotExpression(self, node, visited_children):
        bang, body = visited_children
        assert bang.type == '!'
        return LogicalNotExpression(body)

    def visit_SizeofExpression(self, node, visited_children):
        sizeof, argument = visited_children[0]
        assert sizeof.type == 'sizeof'
        return SizeofExpression(argument)

    def visit_TermExpression(self, node, visited_children) -> Expression:
        base, suffix = visited_children
        if suffix is not None:
            for op, factor in suffix:
                if op.type == '*':
                    base = MultiplyExpression(base, factor)
                else:
                    raise NotImplementedError('term suffix ' + op)
        return base

    def visit_ArithmeticExpression(self, node, visited_children) -> Expression:
        base, suffix = visited_children
        if suffix is not None:
            for op, term in suffix:
                if op.type == '+':
                    base = AddExpression(base, term)
                else:
                    raise NotImplementedError('arithmetic suffix ' + op)
        return base

    def visit_GreaterEqExpression(self, node, visited_children):
        value1, op, value2 = visited_children
        assert op.type == '>='
        return ComparisonExpression(value1, '>=', value2)

    def visit_LessEqExpression(self, node, visited_children):
        value1, op, value2 = visited_children
        assert op.type == '<='
        return ComparisonExpression(value1, '<=', value2)

    def generic_visit(self, node, visited_children):
        """ The generic visit method. """
        if not visited_children:
            if len(node.text) == 0:
                return []
            if len(node.text) == 1:
                return node.text[0]
            raise ValueError('just a node: ' + str(node))
        if len(visited_children) == 1:
        if isinstance(node.expr, TokenMatcher):
            return node.text[0]
        if isinstance(node.expr, OneOf):
            return visited_children[0]
        if isinstance(node.expr, Optional):
            if len(visited_children) == 0:
                return None
            return visited_children[0]
        return visited_children
        if isinstance(node.expr, Sequence) and node.expr.name != '':
            raise NotImplementedError('visit for sequence ' + str(node.expr))
        if isinstance(node.expr, Compound):
            return visited_children
        print(node.expr)
        return super(ASTBuilder, self).generic_visit(node, visited_children)


def build_ast(parse_tree, include_dirs):

M tests/test_ast.py => tests/test_ast.py +352 -4
@@ 1,12 1,43 @@
import dataclasses
import unittest

from crowbar_reference_compiler import build_ast, parse_header, scan
from crowbar_reference_compiler.ast import ArrayType, BasicType, ConstantExpression, EnumDeclaration, HeaderFile, \
    PointerType, StructDeclaration, UnionDeclaration, VariableDeclaration, VariableExpression
from crowbar_reference_compiler import build_ast, parse_header, parse_implementation, scan
from crowbar_reference_compiler.ast import (
    AddExpression,
    AddressOfExpression,
    ArrayIndexExpression,
    ArrayType,
    DirectAssignment,
    BasicType,
    ComparisonExpression,
    ConstType,
    ConstantExpression,
    EnumDeclaration,
    ExpressionStatement,
    FunctionCallExpression,
    FunctionDeclaration,
    FunctionDefinition,
    HeaderFile,
    IfStatement,
    ImplementationFile,
    LogicalNotExpression,
    MultiplyExpression,
    NegativeExpression,
    PointerType,
    ReturnStatement,
    SizeofExpression,
    StructDeclaration,
    StructPointerElementExpression,
    UnionDeclaration,
    UpdateAssignment,
    VariableDeclaration,
    VariableDefinition,
    VariableExpression
)


class TestAST(unittest.TestCase):
    def test_kitchen_sink(self):
    def test_type_kitchen_sink(self):
        code = r"""
struct normal {
    bool fake;


@@ 50,5 81,322 @@ fragile union not_robust {
        self.assertEqual(decls, HeaderFile([], [normal, ope, sample, robust, not_robust]))


class TestRealCode(unittest.TestCase):
    str_hro_code = r"""
struct str {
    (uint8[size])* str;
    uintsize len;
    uintsize size;
}

struct str *str_create();
void str_free(struct str *str);
void str_reset(struct str *str);
intsize str_append_ch(struct str *str, uint32 ch);
"""

    unicode_hro_code = r"""
// Technically UTF-8 supports up to 6 byte codepoints, but Unicode itself
// doesn't really bother with more than 4.
const intsize UTF8_MAX_SIZE = 4;

const uint8 UTF8_INVALID = 0x80;

/**
 * Grabs the next UTF-8 character and advances the string pointer
 */
uint32 utf8_decode(((const uint8)*)* str);

/**
 * Encodes a character as UTF-8 and returns the length of that character.
 */
intsize utf8_encode(uint8 *str, uint32 ch);

/**
 * Returns the size of the next UTF-8 character
 */
intsize utf8_size((const uint8)* str);

/**
 * Returns the size of a UTF-8 character
 */
intsize utf8_chsize(uint32 ch);

/**
 * Reads and returns the next character from the file.
 */
uint32 utf8_fgetch(struct FILE *f);

/**
 * Writes this character to the file and returns the number of bytes written.
 */
intsize utf8_fputch(struct FILE *f, uint32 ch);
"""

    string_cro_code = r"""
//include "stdlib.hro";
//include "stdint.hro";
include "str.hro";
include "unicode.hro";

bool ensure_capacity(struct str *str, intsize len) {
    if (len + 1 >= str->size) {
        (uint8[str->size * 2])* new = realloc(str->str, str->size * 2);
        if (!new) {
            return false;
        }
        str->str = new;
        str->size *= 2;
    }
    return true;
}

struct str *str_create() {
    struct str *str = calloc(1, sizeof(struct str));
    str->str = malloc(16);
    str->size = 16;
    str->len = 0;
    str->str[0] = '\0';
    return str;
}

void str_free(struct str *str) {
    if (!str) {
        return;
    }
    free(str->str);
    free(str);
}

intsize str_append_ch(struct str *str, uint32 ch) {
    intsize size = utf8_chsize(ch);
    if (size <= 0) {
        return -1;
    }
    if (!ensure_capacity(str, str->len + size)) {
        return -1;
    }
    utf8_encode(&str->str[str->len], ch);
    str->len += size;
    str->str[str->len] = '\0';
    return size;
}
"""

    def test_str_hro(self):
        code = self.str_hro_code
        tokens = scan(code)
        parse_tree = parse_header(tokens)
        ast = build_ast(parse_tree, [])
        struct_str = StructDeclaration('str', [
            VariableDeclaration('str', PointerType(ArrayType(BasicType('uint8'), VariableExpression('size')))),
            VariableDeclaration('len', BasicType('uintsize')),
            VariableDeclaration('size', BasicType('uintsize')),
        ])

        pointer_to_struct_str = PointerType(BasicType('struct str'))

        str_create = FunctionDeclaration('str_create', pointer_to_struct_str, [])
        str_free = FunctionDeclaration('str_free', BasicType('void'), [
            VariableDeclaration('str', pointer_to_struct_str)
        ])
        str_reset = FunctionDeclaration('str_reset', BasicType('void'), [
            VariableDeclaration('str', pointer_to_struct_str)
        ])
        str_append_ch = FunctionDeclaration('str_append_ch', BasicType('intsize'), [
            VariableDeclaration('str', pointer_to_struct_str),
            VariableDeclaration('ch', BasicType('uint32'))
        ])

        self.assertEqual(ast, HeaderFile([], [struct_str, str_create, str_free, str_reset, str_append_ch]))

    def test_unicode_hro(self):
        code = self.unicode_hro_code
        tokens = scan(code)
        parse_tree = parse_header(tokens)
        ast = build_ast(parse_tree, [])

        utf8_max_size = VariableDefinition('UTF8_MAX_SIZE', ConstType(BasicType('intsize')), ConstantExpression('4'))
        utf8_invalid = VariableDefinition('UTF8_INVALID', ConstType(BasicType('uint8')), ConstantExpression('0x80'))
        utf8_decode = FunctionDeclaration('utf8_decode', BasicType('uint32'), [
            VariableDeclaration('str', PointerType(PointerType(ConstType(BasicType('uint8'))))),
        ])
        utf8_encode = FunctionDeclaration('utf8_encode', BasicType('intsize'), [
            VariableDeclaration('str', PointerType(BasicType('uint8'))),
            VariableDeclaration('ch', BasicType('uint32')),
        ])
        utf8_size = FunctionDeclaration('utf8_size', BasicType('intsize'), [
            VariableDeclaration('str', PointerType(ConstType(BasicType('uint8')))),
        ])
        utf8_chsize = FunctionDeclaration('utf8_chsize', BasicType('intsize'), [
            VariableDeclaration('ch', BasicType('uint32')),
        ])
        utf8_fgetch = FunctionDeclaration('utf8_fgetch', BasicType('uint32'), [
            VariableDeclaration('f', PointerType(BasicType('struct FILE'))),
        ])
        utf8_fputch = FunctionDeclaration('utf8_fputch', BasicType('intsize'), [
            VariableDeclaration('f', PointerType(BasicType('struct FILE'))),
            VariableDeclaration('ch', BasicType('uint32')),
        ])

        self.assertEqual(ast, HeaderFile([], [utf8_max_size, utf8_invalid, utf8_decode, utf8_encode, utf8_size,
                                              utf8_chsize, utf8_fgetch, utf8_fputch]))

    def test_string_cro(self):
        import tempfile

        code = self.string_cro_code
        tokens = scan(code)
        parse_tree = parse_implementation(tokens)
        with tempfile.TemporaryDirectory() as include_dir:
            with open(f"{include_dir}/str.hro", 'w', encoding='utf-8') as f:
                f.write(self.str_hro_code)
            with open(f"{include_dir}/unicode.hro", 'w', encoding='utf-8') as f:
                f.write(self.unicode_hro_code)
            ast = build_ast(parse_tree, [include_dir])

        included_str_hro = build_ast(parse_header(scan(self.str_hro_code)), [])
        included_unicode_hro = build_ast(parse_header(scan(self.unicode_hro_code)), [])

        expected = ImplementationFile([included_str_hro, included_unicode_hro], [
            FunctionDefinition('ensure_capacity', BasicType('bool'), [
                VariableDeclaration('str', PointerType(BasicType('struct str'))),
                VariableDeclaration('len', BasicType('intsize')),
            ], [
                IfStatement(ComparisonExpression(
                    AddExpression(VariableExpression('len'), ConstantExpression('1')),
                    '>=',
                    StructPointerElementExpression(VariableExpression('str'), 'size')
                ), [
                    VariableDefinition(
                        'new',
                        PointerType(ArrayType(
                            BasicType('uint8'),
                            MultiplyExpression(
                                StructPointerElementExpression(VariableExpression('str'), 'size'),
                                ConstantExpression('2')
                            )
                        )),
                        FunctionCallExpression(VariableExpression('realloc'), [
                            StructPointerElementExpression(VariableExpression('str'), 'str'),
                            MultiplyExpression(
                                StructPointerElementExpression(VariableExpression('str'), 'size'),
                                ConstantExpression('2')
                            )
                        ])
                    ),
                    IfStatement(LogicalNotExpression(VariableExpression('new')), [
                        ReturnStatement(ConstantExpression('false'))
                    ], None),
                    DirectAssignment(
                        StructPointerElementExpression(VariableExpression('str'), 'str'),
                        VariableExpression('new'),
                    ),
                    UpdateAssignment(
                        StructPointerElementExpression(VariableExpression('str'), 'size'),
                        '*=',
                        ConstantExpression('2'),
                    )
                ], None),
                ReturnStatement(ConstantExpression('true')),
            ]),
            FunctionDefinition('str_create', PointerType(BasicType('struct str')), [], [
                VariableDefinition(
                    'str',
                    PointerType(BasicType('struct str')),
                    FunctionCallExpression(
                        VariableExpression('calloc'),
                        [
                            ConstantExpression('1'),
                            SizeofExpression(BasicType('struct str')),
                        ]
                    )
                ),
                DirectAssignment(
                    StructPointerElementExpression(VariableExpression('str'), 'str'),
                    FunctionCallExpression(VariableExpression('malloc'), [ConstantExpression('16')]),
                ),
                DirectAssignment(
                    StructPointerElementExpression(VariableExpression('str'), 'size'),
                    ConstantExpression('16'),
                ),
                DirectAssignment(
                    StructPointerElementExpression(VariableExpression('str'), 'len'),
                    ConstantExpression('0'),
                ),
                DirectAssignment(
                    ArrayIndexExpression(
                        StructPointerElementExpression(VariableExpression('str'), 'str'),
                        ConstantExpression('0'),
                    ),
                    ConstantExpression(r"'\0'"),
                ),
                ReturnStatement(VariableExpression('str')),
            ]),
            FunctionDefinition('str_free', BasicType('void'), [
                VariableDeclaration('str', PointerType(BasicType('struct str')))
            ], [
                IfStatement(LogicalNotExpression(VariableExpression('str')), [
                    ReturnStatement(None)
                ], None),
                ExpressionStatement(FunctionCallExpression(
                    VariableExpression('free'),
                    [StructPointerElementExpression(VariableExpression('str'), 'str')]
                )),
                ExpressionStatement(FunctionCallExpression(
                    VariableExpression('free'),
                    [VariableExpression('str')]
                ))
            ]),
            FunctionDefinition('str_append_ch', BasicType('intsize'), [
                VariableDeclaration('str', PointerType(BasicType('struct str'))),
                VariableDeclaration('ch', BasicType('uint32')),
            ], [
                VariableDefinition('size', BasicType('intsize'), FunctionCallExpression(
                    VariableExpression('utf8_chsize'),
                    [VariableExpression('ch')]
                )),
                IfStatement(ComparisonExpression(VariableExpression('size'), '<=', ConstantExpression('0')), [
                    ReturnStatement(NegativeExpression(ConstantExpression('1')))
                ], None),
                IfStatement(LogicalNotExpression(FunctionCallExpression(
                    VariableExpression('ensure_capacity'),
                    [VariableExpression('str'), AddExpression(
                        StructPointerElementExpression(
                            VariableExpression('str'),
                            'len'
                        ),
                        VariableExpression('size'),
                    )]
                )), [
                    ReturnStatement(NegativeExpression(ConstantExpression('1')))
                ], None),
                ExpressionStatement(FunctionCallExpression(VariableExpression('utf8_encode'), [
                    AddressOfExpression(ArrayIndexExpression(
                        StructPointerElementExpression(VariableExpression('str'), 'str'),
                        StructPointerElementExpression(VariableExpression('str'), 'len'),
                    )),
                    VariableExpression('ch'),
                ])),
                UpdateAssignment(
                    StructPointerElementExpression(VariableExpression('str'), 'len'),
                    '+=',
                    VariableExpression('size')
                ),
                DirectAssignment(
                    ArrayIndexExpression(
                        StructPointerElementExpression(VariableExpression('str'), 'str'),
                        StructPointerElementExpression(VariableExpression('str'), 'len'),
                    ),
                    ConstantExpression(r"'\0'"),
                ),
                ReturnStatement(VariableExpression('size'))
            ])
        ])

        self.assertDictEqual(dataclasses.asdict(ast), dataclasses.asdict(expected))
        self.assertEqual(ast, expected)


if __name__ == '__main__':
    unittest.main()

M tests/test_parsing.py => tests/test_parsing.py +1 -66
@@ 1,73 1,8 @@
import unittest

from crowbar_reference_compiler import parse_header, parse_implementation, scan
from crowbar_reference_compiler import parse_header, scan


class TestParsing(unittest.TestCase):
    def test_basic(self):
        print(parse_header(scan("int8 x();")))

    def test_scdoc_str(self):
        # adapted from https://git.sr.ht/~sircmpwn/scdoc/tree/master/include/str.h
        print(parse_header(scan(r"""
struct str {
    (uint8[size])* str;
    uintsize len;
    uintsize size;
}

struct str *str_create();
void str_free(struct str *str);
void str_reset(struct str *str);
intsize str_append_ch(struct str *str, uint32 ch);
""")))
        # adapted from https://git.sr.ht/~sircmpwn/scdoc/tree/master/src/string.c
        print(parse_implementation(scan(r"""
include "stdlib.hro";
include "stdint.hro";
include "str.hro";
include "unicode.hro";

bool ensure_capacity(struct str *str, intsize len) {
    if (len + 1 >= str->size) {
        (uint8[str->size * 2])* new = realloc(str->str, str->size * 2);
        if (!new) {
            return false;
        }
        str->str = new;
        str->size *= 2;
    }
    return true;
}

struct str *str_create() {
    struct str *str = calloc(1, sizeof(struct str));
    str->str = malloc(16);
    str->size = 16;
    str->len = 0;
    str->str[0] = '\0';
    return str;
}

void str_free(struct str *str) {
    if (!str) {
        return;
    }
    free(str->str);
    free(str);
}

intsize str_append_ch(struct str *str, uint32 ch) {
    intsize size = utf8_chsize(ch);
    if (size <= 0) {
        return -1;
    }
    if (!ensure_capacity(str, str->len + size)) {
        return -1;
    }
    utf8_encode(&str->str[str->len], ch);
    str->len += size;
    str->str[str->len] = '\0';
    return size;
}
""")))