from dataclasses import dataclass
from pathlib import Path
import typing
from typing import ClassVar, List, Tuple, Union
from parsimonious import NodeVisitor # type: ignore
from parsimonious.expressions import Compound, OneOf, Optional, Sequence, TokenMatcher, ZeroOrMore # type: ignore
from .scanner import scan
from .parser import parse_header
@dataclass
class Type:
    """Base class for every type node of the AST.

    Subclasses are expected to override :meth:`size_bytes`; the base
    implementation only reports that the concrete class did not provide one.
    """

    def size_bytes(self, declarations: List['Declaration']) -> int:
        """Return the size of this type in bytes.

        Args:
            declarations: known declarations a subclass may search.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(f'type.size_bytes() on {type(self)} not implemented')
@dataclass
class Expression:
    """Base class for every expression node of the AST.

    Subclasses that can report their static type override :meth:`type`.
    """

    def type(self, declarations: List['Declaration']) -> Type:
        """Return the static type of this expression.

        Args:
            declarations: known declarations a subclass may search.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # ``type`` below is the builtin: the method name lives in the class
        # namespace and is not visible as a bare name inside the method body.
        raise NotImplementedError(f'expression.type() on {type(self)} not implemented')
@dataclass
class ConstantExpression(Expression):
    """A literal constant, kept as its textual form."""

    # Text of the literal; the type is inferred from it in ``type``.
    value: str

    def type(self, _: List['Declaration']) -> Type:
        """Infer the literal's type from the shape of its text.

        The declarations argument is ignored. Checks are applied in order:
        leading ``"`` -> pointer to const char, leading ``'`` -> char,
        ``true``/``false`` -> bool, contains ``.`` -> ``float?``,
        anything else -> ``int?``.
        """
        text = self.value
        if text.startswith('"'):
            return PointerType(ConstType(BasicType('char')))
        if text.startswith("'"):
            return BasicType('char')
        if text in ('true', 'false'):
            return BasicType('bool')
        if '.' in text:
            return BasicType('float?')  # TODO infer size
        return BasicType('int?')  # TODO infer size and signedness
@dataclass
class VariableExpression(Expression):
    """A reference to a named variable or function."""

    # Identifier being referenced.
    name: str

    def type(self, declarations: List['Declaration']) -> Type:
        """Look the name up in ``declarations`` and return its type.

        Variable declarations/definitions yield their declared type;
        function declarations/definitions yield a ``FunctionType`` built
        from the return type and the argument types. Same-named
        declarations of any other kind are skipped.

        Raises:
            KeyError: if no usable declaration carries this name.
        """
        same_name = (decl for decl in declarations if decl.name == self.name)
        for candidate in same_name:
            if isinstance(candidate, (VariableDeclaration, VariableDefinition)):
                return candidate.type
            if isinstance(candidate, (FunctionDeclaration, FunctionDefinition)):
                parameter_types = [arg.type for arg in candidate.args]
                return FunctionType(candidate.return_type, parameter_types)
        raise KeyError('unknown variable ' + self.name)
@dataclass
class AddExpression(Expression):
    """Sum of two sub-expressions (built for the ``+`` operator)."""

    # Left-hand operand.
    term1: Expression
    # Right-hand operand.
    term2: Expression
@dataclass
class SubtractExpression(Expression):
    """Difference of two sub-expressions (built for the binary ``-`` operator)."""

    # Left-hand operand.
    term1: Expression
    # Right-hand operand.
    term2: Expression
@dataclass
class MultiplyExpression(Expression):
    """Product of two sub-expressions (built for the binary ``*`` operator)."""

    # Left-hand operand.
    factor1: Expression
    # Right-hand operand.
    factor2: Expression
@dataclass
class StructPointerElementExpression(Expression):
    """Member access through a struct pointer (``base->element``)."""

    # Expression that must evaluate to a pointer to a struct.
    base: Expression
    # Name of the accessed member.
    element: str

    def type(self, declarations: List['Declaration']) -> Type:
        """Return the declared type of the accessed member.

        The base must be a ``PointerType`` whose target is a ``BasicType``
        named ``"struct <name>"``; this is enforced with assertions.

        Raises:
            KeyError: if the struct is unknown, is opaque (no field list),
                or has no member called ``element``.
        """
        pointer_type = self.base.type(declarations)
        assert isinstance(pointer_type, PointerType)
        assert isinstance(pointer_type.target, BasicType)
        keyword, struct_name = pointer_type.target.name.split(' ')
        assert keyword == 'struct'
        for decl in declarations:
            if not (isinstance(decl, StructDeclaration) and decl.name == struct_name):
                continue
            if decl.fields is None:
                raise KeyError('struct ' + struct_name + ' is opaque')
            for member in decl.fields:
                if member.name == self.element:
                    return member.type
            # The first matching struct is authoritative: stop searching.
            raise KeyError('element ' + self.element + ' not found in struct ' + struct_name)
        raise KeyError('struct ' + struct_name + ' not found')
@dataclass
class ArrayIndexExpression(Expression):
    """Subscript expression (``array[index]``)."""

    # Expression being indexed.
    array: Expression
    # Expression between the brackets.
    index: Expression
@dataclass
class FunctionCallExpression(Expression):
    """Call expression (``function(arguments...)``)."""

    # Expression that is being called.
    function: Expression
    # Argument expressions in call order; empty when there are none.
    arguments: List[Expression]
@dataclass
class LogicalNotExpression(Expression):
    """Prefix ``!`` applied to a sub-expression."""

    # Operand of the ``!`` operator.
    body: Expression
@dataclass
class NegativeExpression(Expression):
    """Prefix ``-`` applied to a sub-expression."""

    # Operand of the unary minus.
    body: Expression
@dataclass
class AddressOfExpression(Expression):
    """Prefix ``&`` applied to a sub-expression."""

    # Operand of the ``&`` operator.
    body: Expression
@dataclass
class SizeofExpression(Expression):
    """A ``sizeof`` expression."""

    # The operand: either a type or an expression.
    body: Union[Type, Expression]
@dataclass
class ComparisonExpression(Expression):
    """Comparison of two sub-expressions with a textual operator."""

    # Left-hand operand.
    value1: Expression
    # Operator text, e.g. ``'>='`` or ``'<='``.
    op: str
    # Right-hand operand.
    value2: Expression
@dataclass
class BasicType(Type):
    """A type identified purely by name (e.g. ``uint8`` or ``struct foo``)."""

    # Type name; struct types are spelled ``"struct <name>"``.
    name: str

    def size_bytes(self, declarations: List['Declaration']) -> int:
        """Return the size of this type in bytes.

        ``uint8`` is 1 byte and ``uintsize`` is 8 bytes. A struct's size is
        the plain sum of its field sizes, taken from the first matching
        ``StructDeclaration`` in ``declarations``.

        Raises:
            KeyError: if the matching struct is opaque (has no field list).
            NotImplementedError: for any other name, or an unknown struct.
        """
        fixed_sizes = {'uint8': 1, 'uintsize': 8}
        if self.name in fixed_sizes:
            return fixed_sizes[self.name]
        if self.name.startswith('struct'):
            _, struct_name = self.name.split(' ')
            for decl in declarations:
                if not (isinstance(decl, StructDeclaration) and decl.name == struct_name):
                    continue
                if decl.fields is None:
                    raise KeyError('struct ' + struct_name + ' is opaque')
                total = 0
                for member in decl.fields:
                    total += member.type.size_bytes(declarations)
                return total
        raise NotImplementedError('size of ' + str(self) + ' not yet found')
@dataclass
class ConstType(Type):
    """A ``const``-qualified wrapper around another type."""

    # The type being qualified.
    target: Type
@dataclass
class PointerType(Type):
    """A pointer to another type."""

    # The pointed-to type.
    target: Type

    def size_bytes(self, declarations: List['Declaration']) -> int:
        """Return the pointer size in bytes; currently always 8."""
        pointer_size = 8  # TODO figure out 32 bit vs 64 bit
        return pointer_size
@dataclass
class ArrayType(Type):
    """An array type (``contents[size]``)."""

    # Element type.
    contents: Type
    # Expression written between the brackets.
    size: Expression
@dataclass
class FunctionType(Type):
    """The type of a function: a return type plus argument types."""

    # Type of the value the function returns.
    return_type: Type
    # Argument types, in declaration order.
    args: List[Type]
@dataclass
class HeaderFileElement:
    """Marker base class for nodes that may be stored in a header file's contents."""
@dataclass
class ImplementationFileElement:
    """Marker base class for nodes that may be stored in an implementation file's contents."""
@dataclass
class Statement:
    """Marker base class for every statement node of the AST."""
@dataclass
class EmptyStatement(Statement):
    """A statement that carries no fields."""
@dataclass
class FragileStatement(Statement):
    """A statement wrapping another statement.

    NOTE(review): presumably models a ``fragile``-qualified statement;
    confirm against the grammar.
    """

    # The wrapped statement.
    body: Statement
@dataclass
class ExpressionStatement(Statement):
    """An expression used as a statement (``expression;``)."""

    # The expression that forms the statement.
    body: Expression
@dataclass
class IfStatement(Statement):
    """An ``if`` statement with an optional ``else`` branch."""

    # Expression deciding which branch is taken.
    condition: Expression
    # Statements of the ``if`` branch.
    then: List[Statement]
    # Statements of the ``else`` branch, or None when there is no ``else``.
    els: typing.Optional[List[Statement]]
@dataclass
class SwitchStatement(Statement):
    """A ``switch`` statement over an expression."""

    # The expression being switched on.
    expression: Expression
    # Mixed list of optional expressions and statements.
    # NOTE(review): presumably case labels (None for a default label)
    # interleaved with the statements they guard; confirm with the producer.
    body: List[Union[typing.Optional[Expression], Statement]]
@dataclass
class WhileStatement(Statement):
    """A ``while`` loop."""

    # Loop condition.
    condition: Expression
    # Statements of the loop body.
    body: List[Statement]
@dataclass
class DoWhileStatement(Statement):
    """A do-while loop; the fields mirror ``WhileStatement``."""

    # Loop condition.
    condition: Expression
    # Statements of the loop body.
    body: List[Statement]
@dataclass
class Declaration:
    """Base class for everything that introduces a name."""

    # The declared identifier; lookups in this module match on it.
    name: str
@dataclass
class VariableDeclaration(Declaration, HeaderFileElement):
    """Declaration of a variable: a name (inherited) and a type.

    Also used for struct fields and function arguments.
    """

    # Declared type of the variable.
    type: Type
@dataclass
class VariableDefinition(Declaration, HeaderFileElement, ImplementationFileElement, Statement):
    """Definition of a variable: a name (inherited), a type and an initial value.

    Usable as a file element in both file kinds and as a statement.
    """

    # Declared type of the variable.
    type: Type
    # Initialiser expression.
    value: Expression
@dataclass
class AssignmentStatement(Statement):
    """Field-less base class shared by the assignment statement kinds."""
@dataclass
class ForStatement(Statement):
    """The header of a ``for`` loop: initialisers, condition and updates.

    NOTE(review): unlike ``WhileStatement`` there is no ``body`` field here;
    confirm whether that is intentional.
    """

    # Variable definitions of the init clause.
    init: List[VariableDefinition]
    # Loop condition.
    condition: Expression
    # Assignments of the update clause.
    update: List[AssignmentStatement]
@dataclass
class ContinueStatement(Statement):
    """Field-less node; by its name, a ``continue`` statement."""
@dataclass
class BreakStatement(Statement):
    """Field-less node; by its name, a ``break`` statement."""
@dataclass
class ReturnStatement(Statement):
    """A ``return`` statement with an optional value."""

    # Returned expression, or None when absent.
    body: typing.Optional[Expression]
@dataclass
class DirectAssignment(AssignmentStatement):
    """Plain assignment (``destination = value``)."""

    # Expression that is assigned to.
    destination: Expression
    # Expression whose value is stored.
    value: Expression
@dataclass
class UpdateAssignment(AssignmentStatement):
    """Compound assignment such as ``destination += value``."""

    # Expression that is both read and assigned to.
    destination: Expression
    # Operator text, e.g. ``'+='``.
    operation: str
    # Right-hand operand.
    value: Expression

    def deconstruct(self) -> DirectAssignment:
        """Rewrite ``a op= b`` as the equivalent ``a = a op b``.

        Supports ``+=``, ``-=`` and ``*=``.

        Raises:
            NotImplementedError: for any other operator.
        """
        # ``-=`` is handled alongside ``+=``/``*=`` because the AST already
        # has a SubtractExpression node for binary ``-``.
        expression_classes = {
            '+=': AddExpression,
            '-=': SubtractExpression,
            '*=': MultiplyExpression,
        }
        expression_class = expression_classes.get(self.operation)
        if expression_class is None:
            raise NotImplementedError('UpdateAssignment deconstruct with ' + self.operation)
        return DirectAssignment(self.destination, expression_class(self.destination, self.value))
@dataclass
class CrementAssignment(AssignmentStatement):
    """Assignment made of a destination and an operator only.

    NOTE(review): presumably increment/decrement (``++`` / ``--``); confirm
    against the grammar.
    """

    # Expression that is modified.
    destination: Expression
    # Operator text.
    operation: str
@dataclass
class StructDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
    """Declaration of a struct type, identified by its inherited name."""

    # Member declarations, or None for an opaque struct.
    fields: typing.Optional[List[VariableDeclaration]]
@dataclass
class EnumDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
    """Declaration of an enum type, identified by its inherited name."""

    # (member name, optional explicit value) pairs in source order.
    # NOTE(review): ASTBuilder.visit_EnumMember is annotated as producing an
    # optional Expression rather than an int; confirm which is intended.
    values: List[Tuple[str, typing.Optional[int]]]
@dataclass
class UnionDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
    """Declaration of a union type, identified by its inherited name."""

    # Tag variable, or None for a union without one (fragile unions).
    tag: typing.Optional[VariableDeclaration]
    # Either plain member declarations, or (case expression, optional
    # member declaration) pairs.
    cases: Union[List[VariableDeclaration], List[Tuple[Expression, typing.Optional[VariableDeclaration]]]]
@dataclass
class FunctionDeclaration(Declaration, HeaderFileElement):
    """Declaration of a function: its signature without a body."""

    # Type of the returned value.
    return_type: Type
    # Arguments in declaration order; empty when there are none.
    args: List[VariableDeclaration]
@dataclass
class FunctionDefinition(Declaration, HeaderFileElement, ImplementationFileElement):
    """Definition of a function: its signature together with a body."""

    # Type of the returned value.
    return_type: Type
    # Arguments in declaration order.
    args: List[VariableDeclaration]
    # Statements making up the function body.
    body: List[Statement]
@dataclass
class HeaderFile:
    """A parsed header file: its includes followed by its own elements."""

    # Grammar rule text for this node (class-level, not a dataclass field).
    grammar: ClassVar[str] = "HeaderFile <- IncludeStatement* HeaderFileElement+"
    # Already-parsed headers pulled in by include statements.
    includes: List['HeaderFile']
    # Top-level elements of this header.
    contents: List[HeaderFileElement]

    def get_declarations(self) -> List[Declaration]:
        """Return all visible declarations as a new flat list.

        Declarations of every included header come first (recursively, in
        include order), followed by this file's own declarations.
        """
        collected: List[Declaration] = []
        for header in self.includes:
            collected.extend(header.get_declarations())
        collected.extend(item for item in self.contents if isinstance(item, Declaration))
        return collected
@dataclass
class ImplementationFile:
    """A parsed implementation file: its includes followed by its own elements."""

    # Already-parsed headers pulled in by include statements.
    includes: List[HeaderFile]
    # Top-level elements of this file.
    contents: List[ImplementationFileElement]

    def get_declarations(self) -> List[Declaration]:
        """Return all visible declarations as a new flat list.

        Declarations of every included header come first (in include
        order), followed by this file's own declarations.
        """
        visible: List[Declaration] = []
        for header in self.includes:
            visible += header.get_declarations()
        for element in self.contents:
            if isinstance(element, Declaration):
                visible.append(element)
        return visible
# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
class ASTBuilder(NodeVisitor):
    """Parse-tree visitor that turns grammar nodes into AST dataclasses.

    Every ``visit_<Rule>`` method receives the already-visited children of a
    node produced by grammar rule ``<Rule>`` and returns the matching AST
    object. Leaf children are token objects read through their ``.type`` and
    ``.data`` attributes. The ``assert`` statements document the token shapes
    the grammar is expected to deliver.
    """

    def __init__(self, include_folders):
        """Store the folders searched when resolving include statements."""
        self.include_folders = include_folders

    # ------------------------------------------------------------------
    # Files and includes
    # ------------------------------------------------------------------

    def visit_HeaderFile(self, node, visited_children) -> HeaderFile:
        """Build a ``HeaderFile`` from (includes, elements)."""
        includes, elements = visited_children
        return HeaderFile(includes, elements)

    def visit_ImplementationFile(self, node, visited_children) -> ImplementationFile:
        """Build an ``ImplementationFile`` from (includes, elements)."""
        includes, elements = visited_children
        return ImplementationFile(includes, elements)

    def visit_IncludeStatement(self, node, visited_children) -> HeaderFile:
        """Locate, scan, parse and visit the header named by an include.

        The include folders are tried in order and the first existing file
        wins.

        Raises:
            FileNotFoundError: if no include folder contains the header.
        """
        include, included_header, semicolon = visited_children
        assert include.type == 'include'
        assert included_header.type == 'string_literal'
        included_header = included_header.data.strip('"')
        assert semicolon.type == ';'
        for include_folder in self.include_folders:
            header = Path(include_folder) / included_header
            if header.exists():
                with open(header, 'r', encoding='utf-8') as header_file:
                    header_text = header_file.read()
                header_parse_tree = parse_header(scan(header_text))
                return self.visit(header_parse_tree)
        raise FileNotFoundError(included_header)

    # ------------------------------------------------------------------
    # Struct, enum and union declarations
    # ------------------------------------------------------------------

    def visit_NormalStructDefinition(self, node, visited_children) -> StructDeclaration:
        """Build a struct declaration with a field list."""
        struct, name, lbrace, fields, rbrace = visited_children
        assert struct.type == 'struct'
        assert name.type == 'identifier'
        name = name.data
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        return StructDeclaration(name, fields)

    def visit_OpaqueStructDefinition(self, node, visited_children) -> StructDeclaration:
        """Build an opaque struct declaration (``fields`` is None)."""
        opaque, struct, name, semi = visited_children
        assert opaque.type == 'opaque'
        assert struct.type == 'struct'
        assert name.type == 'identifier'
        name = name.data
        assert semi.type == ';'
        return StructDeclaration(name, None)

    def visit_EnumDefinition(self, node, visited_children) -> EnumDeclaration:
        """Build an enum declaration from its first and following members."""
        enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children
        assert enum.type == 'enum'
        assert name.type == 'identifier'
        name = name.data
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        # Each extra member arrives as a (comma, member) pair.
        values = [first_member] + [member for _, member in extra_members]
        return EnumDeclaration(name, values)

    def visit_EnumMember(self, node, visited_children) -> Tuple[str, typing.Optional[Expression]]:
        """Return (member name, explicit value or None)."""
        name, equals_value = visited_children
        assert name.type == 'identifier'
        name = name.data
        if equals_value is None:
            return name, None
        _, value = equals_value
        return name, value

    def visit_RobustUnionDefinition(self, node, visited_children) -> UnionDeclaration:
        """Build a tagged union; the tag must match the switch argument.

        Raises:
            NameError: if the tag's name differs from the switch argument.
        """
        union, name, lbrace, tag, body, rbrace = visited_children
        assert union.type == 'union'
        assert name.type == 'identifier'
        name = name.data
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        expected_tagname, body = body
        if tag.name != expected_tagname:
            raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")
        if not isinstance(body, list):
            body = [body]
        return UnionDeclaration(name, tag, body)

    def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, typing.Optional[VariableDeclaration]]]]:
        """Return (switch argument name, union cases)."""
        switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
        assert switch.type == 'switch'
        assert lparen.type == '('
        assert rparen.type == ')'
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        return tag.data, body

    def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, typing.Optional[VariableDeclaration]]:
        """Return (case expression, member declaration or None)."""
        cases, var = visited_children
        if isinstance(cases, list):
            # Only the first case specifier of the set is kept.
            cases = cases[0]
        if isinstance(var, VariableDeclaration):
            return cases, var
        return cases, None

    def visit_CaseSpecifier(self, node, visited_children) -> Expression:
        """Return the expression of a ``case <expr>:`` specifier."""
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        # TODO don't explode on 'default:'
        case, expr, colon = visited_children
        return expr

    def visit_FragileUnionDefinition(self, node, visited_children) -> UnionDeclaration:
        """Build an untagged (``fragile``) union declaration."""
        fragile, union, name, lbrace, body, rbrace = visited_children
        assert fragile.type == 'fragile'
        assert union.type == 'union'
        assert name.type == 'identifier'
        name = name.data
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        return UnionDeclaration(name, None, body)

    # ------------------------------------------------------------------
    # Variables and functions
    # ------------------------------------------------------------------

    def visit_FunctionDeclaration(self, node, visited_children) -> FunctionDeclaration:
        """Return the signature of a ``signature;`` declaration."""
        signature, semi = visited_children
        assert semi.type == ';'
        return signature

    def visit_VariableDefinition(self, node, visited_children) -> VariableDefinition:
        """Build a variable definition from ``type name = value;``."""
        ty, name, eq, value, semi = visited_children
        assert name.type == 'identifier'
        name = name.data
        assert eq.type == '='
        assert semi.type == ';'
        return VariableDefinition(name, ty, value)

    def visit_VariableDeclaration(self, node, visited_children) -> VariableDeclaration:
        """Build a variable declaration from ``type name;``."""
        ty, name, semi = visited_children
        assert name.type == 'identifier'
        name = name.data
        assert semi.type == ';'
        return VariableDeclaration(name, ty)

    def visit_FunctionDefinition(self, node, visited_children) -> FunctionDefinition:
        """Combine a signature and a block into a function definition."""
        signature, body = visited_children
        return FunctionDefinition(signature.name, signature.return_type, signature.args, body)

    def visit_FunctionSignature(self, node, visited_children) -> FunctionDeclaration:
        """Build a declaration from ``return_type name(args)``."""
        return_type, name, lparen, args, rparen = visited_children
        assert name.type == 'identifier'
        name = name.data
        assert lparen.type == '('
        if args is None:
            # An absent argument list means "no arguments".
            args = []
        assert rparen.type == ')'
        return FunctionDeclaration(name, return_type, args)

    def visit_SignatureArguments(self, node, visited_children) -> List[VariableDeclaration]:
        """Return the argument declarations in source order."""
        first_type, first_name, rest, _trailing_comma = visited_children
        result = [VariableDeclaration(first_name.data, first_type)]
        # Each following argument arrives as a (comma, type, name) triple.
        for _comma, ty, name in rest:
            result.append(VariableDeclaration(name.data, ty))
        return result

    # ------------------------------------------------------------------
    # Statements
    # ------------------------------------------------------------------

    def visit_IfStatement(self, node, visited_children) -> IfStatement:
        """Build an ``if`` statement; ``els`` stays None without ``else``."""
        kwd, lparen, condition, rparen, then, els = visited_children
        assert kwd.type == 'if'
        assert lparen.type == '('
        assert rparen.type == ')'
        if els is not None:
            kwd, els = els
            assert kwd.type == 'else'
        return IfStatement(condition, then, els)

    def visit_WhileStatement(self, node, visited_children) -> WhileStatement:
        """Build a ``while`` statement."""
        kwd, lparen, condition, rparen, body = visited_children
        assert kwd.type == 'while'
        assert lparen.type == '('
        assert rparen.type == ')'
        return WhileStatement(condition, body)

    def visit_ReturnStatement(self, node, visited_children) -> ReturnStatement:
        """Build a ``return`` statement."""
        ret, body, semi = visited_children
        assert ret.type == 'return'
        assert semi.type == ';'
        return ReturnStatement(body)

    def visit_DirectAssignmentBody(self, node, visited_children) -> DirectAssignment:
        """Build ``dest = value``."""
        dest, eq, value = visited_children
        assert eq.type == '='
        return DirectAssignment(dest, value)

    def visit_UpdateAssignmentBody(self, node, visited_children) -> UpdateAssignment:
        """Build ``dest op= value``; the operator token's type is the operation."""
        dest, op, value = visited_children
        return UpdateAssignment(dest, op.type, value)

    def visit_AssignmentStatement(self, node, visited_children):
        """Return the assignment of an ``assignment;`` statement."""
        assignment, semi = visited_children
        assert semi.type == ';'
        return assignment

    def visit_ExpressionStatement(self, node, visited_children) -> ExpressionStatement:
        """Wrap the expression of an ``expression;`` statement."""
        expression, semi = visited_children
        assert semi.type == ';'
        return ExpressionStatement(expression)

    # ------------------------------------------------------------------
    # Types
    # ------------------------------------------------------------------

    def visit_BasicType(self, node, visited_children) -> Type:
        """Build a basic type.

        Three shapes are handled after unwrapping single-element lists:
        a parenthesised type (returned as is), a ``category name`` pair
        (e.g. ``struct foo``), or a single token whose type is the name.
        """
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        if not isinstance(visited_children, list):
            return BasicType(visited_children.type)
        if len(visited_children) == 3:
            # parenthesized!
            lparen, ty, rparen = visited_children
            assert lparen.type == '('
            assert rparen.type == ')'
            return ty
        category, name = visited_children
        category = category.type
        assert name.type == 'identifier'
        name = name.data
        return BasicType(f"{category} {name}")

    def visit_ConstType(self, node, visited_children) -> ConstType:
        """Build ``const <type>``."""
        const, contents = visited_children
        assert const.type == 'const'
        return ConstType(contents)

    def visit_FunctionType(self, node, visited_children):
        """Function types are not supported yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('function types')

    def visit_ArrayType(self, node, visited_children) -> ArrayType:
        """Build ``<type>[size]``."""
        contents, lbracket, size, rbracket = visited_children
        assert lbracket.type == '['
        assert rbracket.type == ']'
        return ArrayType(contents, size)

    def visit_PointerType(self, node, visited_children) -> PointerType:
        """Build ``<type>*``."""
        contents, splat = visited_children
        assert splat.type == '*'
        return PointerType(contents)

    def visit_Block(self, node, visited_children) -> List[Statement]:
        """Return the contents of a ``{ ... }`` block.

        The result is stored in statement-list fields such as
        ``FunctionDefinition.body``.
        """
        lbrace, body, rbrace = visited_children
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        return body

    # ------------------------------------------------------------------
    # Expressions
    # ------------------------------------------------------------------

    def visit_AtomicExpression(self, node, visited_children) -> Expression:
        """Build an identifier, a constant, or unwrap a parenthesised expression.

        Raises:
            NotImplementedError: for an unrecognised token type.
        """
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        if isinstance(visited_children, list) and len(visited_children) == 3:
            lparen, body, rparen = visited_children
            assert lparen.type == '('
            assert rparen.type == ')'
            return body
        body = visited_children
        while isinstance(body, list):
            body = body[0]
        if body.type == 'identifier':
            return VariableExpression(body.data)
        if body.type == 'constant':
            return ConstantExpression(body.data)
        if body.type in ['true', 'false']:
            # Boolean keywords carry their value in the token type.
            return ConstantExpression(body.type)
        if body.type == 'string_literal':
            return ConstantExpression(body.data)
        raise NotImplementedError('atomic expression ' + repr(body))

    def visit_StructPointerElementSuffix(self, node, visited_children):
        """Return a function that applies ``->element`` to a base expression."""
        separator, element = visited_children
        assert separator.type == '->'
        return lambda base: StructPointerElementExpression(base, element.data)

    def visit_CommasExpressionList(self, node, visited_children):
        """Return the expressions of a comma-separated list, in order."""
        first, rest, _trailing_comma = visited_children
        result = [first]
        # Each following entry arrives as a (comma, expression) pair.
        for _comma, expression in rest:
            result.append(expression)
        return result

    def visit_FunctionCallSuffix(self, node, visited_children):
        """Return a function that applies ``(args)`` to a base expression."""
        lparen, args, rparen = visited_children
        assert lparen.type == '('
        assert rparen.type == ')'
        if args is None:
            args = []
        return lambda base: FunctionCallExpression(base, args)

    def visit_ArrayIndexSuffix(self, node, visited_children):
        """Return a function that applies ``[index]`` to a base expression."""
        lbracket, index, rbracket = visited_children
        assert lbracket.type == '['
        assert rbracket.type == ']'
        return lambda base: ArrayIndexExpression(base, index)

    def visit_ObjectExpression(self, node, visited_children) -> Expression:
        """Apply all suffixes, left to right, to the base expression.

        Raises:
            NotImplementedError: for array/struct literals.
        """
        if isinstance(visited_children, list):
            base, suffixes = visited_children[0]
            # Each suffix is a callable produced by one of the *Suffix visitors.
            for apply_suffix in suffixes:
                base = apply_suffix(base)
            return base
        raise NotImplementedError('array/struct literals')

    def visit_NegativeExpression(self, node, visited_children) -> NegativeExpression:
        """Build ``-body``."""
        minus, body = visited_children
        assert minus.type == '-'
        return NegativeExpression(body)

    def visit_AddressOfExpression(self, node, visited_children) -> AddressOfExpression:
        """Build ``&body``."""
        ampersand, body = visited_children
        assert ampersand.type == '&'
        return AddressOfExpression(body)

    def visit_LogicalNotExpression(self, node, visited_children) -> LogicalNotExpression:
        """Build ``!body``."""
        bang, body = visited_children
        assert bang.type == '!'
        return LogicalNotExpression(body)

    def visit_SizeofExpression(self, node, visited_children) -> SizeofExpression:
        """Build ``sizeof <argument>``."""
        sizeof, argument = visited_children[0]
        assert sizeof.type == 'sizeof'
        return SizeofExpression(argument)

    def visit_TermExpression(self, node, visited_children) -> Expression:
        """Fold ``base (* factor)*`` into left-nested multiplications.

        Raises:
            NotImplementedError: for any operator other than ``*``.
        """
        base, suffix = visited_children
        if suffix is not None:
            for op, factor in suffix:
                if op.type == '*':
                    base = MultiplyExpression(base, factor)
                else:
                    # ``op`` is a token, not a string: format its type rather
                    # than concatenating the token itself onto the message.
                    raise NotImplementedError(f'term suffix {op.type}')
        return base

    def visit_ArithmeticExpression(self, node, visited_children) -> Expression:
        """Fold ``base ((+|-) term)*`` into left-nested add/subtract nodes.

        Raises:
            NotImplementedError: for any operator other than ``+`` or ``-``.
        """
        base, suffix = visited_children
        if suffix is not None:
            for op, term in suffix:
                if op.type == '+':
                    base = AddExpression(base, term)
                elif op.type == '-':
                    base = SubtractExpression(base, term)
                else:
                    # ``op`` is a token, not a string: format its type rather
                    # than concatenating the token itself onto the message.
                    raise NotImplementedError(f'arithmetic suffix {op.type}')
        return base

    def visit_GreaterEqExpression(self, node, visited_children) -> ComparisonExpression:
        """Build ``value1 >= value2``."""
        value1, op, value2 = visited_children
        assert op.type == '>='
        return ComparisonExpression(value1, '>=', value2)

    def visit_LessEqExpression(self, node, visited_children) -> ComparisonExpression:
        """Build ``value1 <= value2``."""
        value1, op, value2 = visited_children
        assert op.type == '<='
        return ComparisonExpression(value1, '<=', value2)

    # ------------------------------------------------------------------
    # Fallback
    # ------------------------------------------------------------------

    def generic_visit(self, node, visited_children):
        """Handle nodes that have no dedicated ``visit_*`` method.

        * token matcher: the first element of the node's text
        * one-of: the single visited alternative
        * optional: the visited child, or None when absent
        * named sequence: an error, since it needs its own visitor
        * any other compound: the list of visited children

        Raises:
            NotImplementedError: for a named sequence without a visitor.
        """
        expr = node.expr
        if isinstance(expr, TokenMatcher):
            return node.text[0]
        if isinstance(expr, OneOf):
            return visited_children[0]
        if isinstance(expr, Optional):
            if len(visited_children) == 0:
                return None
            return visited_children[0]
        if isinstance(expr, Sequence) and expr.name != '':
            raise NotImplementedError('visit for sequence ' + str(expr))
        if isinstance(expr, Compound):
            return visited_children
        # Show the unhandled expression before deferring to the base class.
        print(expr)
        return super().generic_visit(node, visited_children)
def build_ast(parse_tree, include_dirs):
    """Convert ``parse_tree`` into an AST.

    Args:
        parse_tree: root node handed to ``ASTBuilder.visit``.
        include_dirs: folders searched when resolving include statements.

    Returns:
        Whatever the visitor produces for the root node.
    """
    return ASTBuilder(include_dirs).visit(parse_tree)