~nch/python-compiler

511ec7acdfbd89498ed561c36bf0d5cbe608a41e — nc 1 year, 11 months ago c6d2b36
add anf transform for if statement
2 files changed, 55 insertions(+), 34 deletions(-)

M compiler.py
M test_compiler.py
M compiler.py => compiler.py +45 -27
@@ 50,18 50,11 @@ class Stream:
    _stream: str
    i: int = 0
    indent: int = 0
    @property
    def stream(self): return self._stream[self.i:]

    @property
    def row(self): return sum((1 for c in self._stream[:self.i] if c == '\n'))

    @property
    def col(self): return next((j for j in range(self.i) if self._stream[self.i - j] == '\n'), self.i)

    stream = property(lambda self: self._stream[self.i:])
    row = property(lambda self: sum((1 for c in self._stream[:self.i] if c == '\n')))
    col = property(lambda self: next((j for j in range(self.i) if self._stream[self.i - j] == '\n'), self.i))
    def empty(self): return self.i >= len(self._stream)


@dataclass
class ParseError:
    stream: Stream # TODO: add an index, use a slice instead of updating `stream` so we can peek backwards for error messages


@@ 356,17 349,23 @@ class Block(list):
        return self

class FunctionCall(tuple):
    @property
    def name(self): return self[0]
    @property
    def args(self): return self[1:]
    def __new__(cls, name, *args):
        return super(FunctionCall, cls).__new__(cls, tuple((name, *args)))
    name = property(lambda s: s[0])
    args = property(lambda s: s[1:])

class Assign(tuple):
    def __new__(cls, a, b): return super(Assign, cls).__new__(cls, tuple(('=', a, b)))
    @property
    def lhs(self): return self[0]
    @property
    def rhs(self): return self[1]
    def __new__(cls, a, b):
        return super(Assign, cls).__new__(cls, tuple(('=', a, b)))
    lhs = property(lambda self: self[1])
    rhs = property(lambda self: self[2])

class If(tuple):
    def __new__(cls, cond, then, otherwise):
        return super(If, cls).__new__(cls, tuple(('if', cond, then, otherwise)))
    cond = property(lambda s: s[1])
    then = property(lambda s: s[2])
    otherwise = property(lambda s: s[3])

def intersperse(p, delimp):
    """


@@ 440,16 439,40 @@ def gensym():
def is_trivial(b):
    return type(b) in {str, int, float}


@singledispatch
def normalize_stmt(stmt): # expr case
    norm, hoisted = normalize_expr(stmt)
    return hoisted.add(norm)

@normalize_stmt.register(Block)
def _(block):
    norm_block = Block()
    for b in block:
        norm_block.add(normalize_stmt(b))
    return norm_block

@normalize_stmt.register(If)
def _(ifnode: If):
    norm_cond, hoisted = normalize_expr(ifnode.cond)
    norm_then = normalize_stmt(ifnode.then)
    norm_otherwise = normalize_stmt(ifnode.otherwise)
    norm_if = If(maybe_hoist(norm_cond, hoisted), normalize_stmt(ifnode.then), normalize_stmt(ifnode.otherwise))
    if hoisted:
        return hoisted.add(norm_if)
    else:
        return norm_if

@singledispatch
def normalize_expr(node):
    return node, Block()

def maybe_hoist(expr, hoisted):
    if is_trivial(expr):
        return expr
    new_var = gensym()
    hoisted.add(Assign(new_var, expr))
    return new_var

@normalize_expr.register(FunctionCall)
def _(node: FunctionCall):
    new_args = []


@@ 457,13 480,8 @@ def _(node: FunctionCall):
    for arg in node.args:
        a, h_ = normalize_expr(arg)
        hoisted.add(h_)
        if is_trivial(a):
            new_args.append(a)
        else:
            new_var = gensym()
            hoisted.add(Assign(new_var, a))
            new_args.append(new_var)
    return FunctionCall([node.name] + new_args), hoisted
        new_args.append(maybe_hoist(a, hoisted))
    return FunctionCall(node.name, *new_args), hoisted

### Code gen


M test_compiler.py => test_compiler.py +10 -7
@@ 68,24 68,27 @@ else:
        self.assertEqual(b2, ['a', 'b', 'c'])

    def test_normalize(self):
        tree1 = FunctionCall(('a', 'b', 'c'))
        tree1 = FunctionCall('a', 'b', 'c')
        self.assertEqual(normalize_expr(tree1)[0], ('a', 'b', 'c'))

        global gensym_counter
        gensym_counter = 0

        tree2 = FunctionCall(('a', FunctionCall(('b')), 'c'))
        tree2 = FunctionCall('a', FunctionCall('b'), 'c')
        self.assertEqual(normalize_stmt(tree2),
                [('=', 'tmp1', ('b',)),
                 ('a', 'tmp1', 'c')])

        gensym_counter = 0
        tree3 = FunctionCall(('a', FunctionCall(('b', FunctionCall(('x',)))), 'c'))
        tree3 = FunctionCall('a', FunctionCall('b', FunctionCall('x')), 'c')
        self.assertEqual(normalize_stmt(tree3),
                [('=', 'tmp2', ('x',)),
                 ('=', 'tmp3', ('b', 'tmp2')),
                 ('a', 'tmp3', 'c')])

        tree4 = If(FunctionCall('b', FunctionCall('x')),
                   'c',
                   FunctionCall('d'))
        self.assertEqual(normalize_stmt(tree4),
                [('=', 'tmp4', ('x',)),
                 ('=', 'tmp5', ('b', 'tmp4')),
                 ('if', 'tmp5', ['c'], [('d',)])])

if __name__ == '__main__':
    unittest.main()