~nch/python-compiler

83fbefef0ee2a8ed0856b7bb731634cde8d7f7bf — nc 2 months ago cb5fb80
start adding infrastructure for functions
2 files changed, 77 insertions(+), 12 deletions(-)

M compiler.py
M test_compiler.py
M compiler.py => compiler.py +51 -7
@@ 3,9 3,11 @@
# set makeprg=python3\ ./compiler.py\ &&\ python3\ ./test_compiler.py

from dataclasses import dataclass
from functools import singledispatch
from functools import singledispatch, reduce
from operator import itemgetter
from typing import Callable, Tuple, List, Union
from typing import Callable, Tuple, List, Dict, Union, TypeVar, Any
from collections.abc import Iterable
import operator
import itertools

class NilToken(str):


@@ 383,8 385,10 @@ def nodeclass(name, fields, hole_values=[]):

FunctionCall = nodeclass('FunctionCall', 'name *args')
Assign = nodeclass('Assign', '_ lhs rhs', ['='])
Let = nodeclass('Let', '_ binding value', ['let'])
If = nodeclass('If', '_ cond then otherwise', ['if'])
Return = nodeclass('Return', '_ value', ['return'])
FunctionDef = nodeclass('FunctionDef', '_ name params body', ['def'])

def intersperse(p, delimp):
    """


@@ 450,12 454,15 @@ if_stmt: ParserT = convert(seq('if', space, expr, discard(char(':')),
stmt: ParserT = convert(seq(oneof(if_stmt, return_stmt_body, assign_stmt_body, expr), newline), lambda x: x[0])
# <block> := <newline> (<indent> <stmt>)+
block: ParserT = convert(seq(newline, one_or_more(convert(seq(indentation('indent'), stmt), lambda x: x[0]))), lambda x: Block(*x[0]))
#import tmp; block = tmp.trace_function(block)
# <function> := 'def' <space> <identifier> '(' (<identifier> (',' <space> <identifier>)*) ')' ':' <newline> <block>
function: ParserT = seq('def', space, identifier, char('('), intersperse(identifier, discard(seq(char(','), space))), char(')'), char(':'), block)

# TODO: I'm starting to feel that it's possible to handle whitespace with a lexer...

### A-normal form normalizer

# TODO: make sure to hoist all function definitions and wrap the main function

gensym_counter = 0
def gensym(prefix='tmp'):
    global gensym_counter


@@ 488,6 495,7 @@ def _(ifnode):
    else:
        return norm_if

# TODO: make sure all functions have a return value
@normalize_stmt.register(Return) # type: ignore
def _(ret):
    n, hoisted = normalize_expr(ret.value)


@@ 513,7 521,7 @@ def maybe_hoist(expr, hoisted):
    if is_trivial(expr):
        return expr
    new_var = gensym()
    hoisted.add(Assign(new_var, expr))
    hoisted.add(Let(new_var, expr))
    return new_var

@normalize_expr.register(FunctionCall) # type: ignore


@@ 634,6 642,7 @@ def modrm(reg1, reg2_or_mem):

    For info: https://wiki.osdev.org/X86-64_Instruction_Encoding#ModR.2FM_and_SIB_bytes
    '''
    # TODO: refactor this spaghetti mess
    reg_p = get_reg_p(reg1)
    regs = get_reg_t(reg1)
    if mem_p(reg_p)(reg2_or_mem):


@@ 776,6 785,7 @@ ops = [

        ## ARITHMETIC
        (('add', reg32_p, reg32_p), lambda _, r1, r2: b'\x01' + modrm(r2, r1)),
        (('add', reg64_p, imm32_p), lambda _, r1, im: b'\x48\x81' + _pack_modrm(0, regs64[r1], 'direct') + pack32(im)),

        (('syscall',), lambda _: b'\x0f\x05'),
        (('int', imm32_p), lambda _, x: b'\xcd' + pack8(x))


@@ 807,8 817,24 @@ def emit(*args):
            return encoder_f(*args)
    assert False, f"unknown opcode for {args}"

def code_gen_module(fs: List, g):
    for f in fs:
        assert isinstance(f, FunctionDef), ('mismatch', type(f), 'expected', FunctionDef)
        emit_label(f.name, g)
        code_gen(f.body, [f.params], g)

        if f.name == 'main':
            postlude = emit('rdi <- rax') +\
                       emit('rax <- 60') +\
                       emit('syscall')
            g.append(postlude)

@singledispatch
def code_gen(node, ctx, g):
def code_gen(node, ctx: List[Dict[str, int]], g):
    '''
    Generates the body (passed in `node`) of a function.
    Note that `node` *must* be a-normalized. TODO: encode this with types
    '''
    assert False, f'unhandled: {node} type {type(node)}'

@code_gen.register(int) # type: ignore


@@ 841,14 867,32 @@ def _(ifstmt, ctx, g):
    emit_label(end_label, g)

@code_gen.register(Block) # type: ignore
def _(block: Block, ctx, g): # type: ignore
    for stmt in block: code_gen(stmt, ctx, g)
def _(block, ctx, g):
    num_lets = traverse_tree(block, operator.add, 0, lambda x: 1 if isinstance(x, Let) else 0)
    emit('add', 'rsp', -num_lets * 8)
    for stmt in block:
        code_gen(stmt, ctx + [[]], g)

@code_gen.register(Assign) # type: ignore
def _(assign: Assign, ctx, g): # type: ignore
    code_gen(assign.rhs, ctx, g)
    g.append(emit(assign.lhs, '<-', 'rax'))

call_registers = ['rdi', 'rsi', 'rdx', 'rcx', 'r8', 'r9']

R = TypeVar('R')
def traverse_tree(tree: Any, f: Callable[[R, Any], R], r: R, key: Callable=lambda x: x) -> R:
    r = f(r, key(tree))
    if isinstance(tree, Iterable):
        for subtree in tree:
            r = traverse_tree(subtree, f, r, key)
    return r

@code_gen.register(Return) # type: ignore
def _(ret, ctx, g):
    emit('rax', '<-', ret.value)
    emit('ret', len(ctx[-1]) * 8)

'''
@code_gen.register(FunctionCall)
def _(f: FunctionCall, ctx, g):

M test_compiler.py => test_compiler.py +26 -5
@@ 76,6 76,18 @@ class TestCompiler(unittest.TestCase):
        self.assertEqual(self.unify([('a', Var('x'), 'c'), ('d', Var('x'), 'f')],
                                    [('a', 'b', 'c'), ('d', 'e', 'f')]), None)

    def test_traverse_tree(self):
        import operator
        self.assertEqual(traverse_tree([1, 2, [3, 4, [5], 6]],
                                        operator.add,
                                        0,
                                        lambda x: x if type(x) == int else 0), 21)
        self.assertEqual(traverse_tree([1, 2, [3, 4, [5], 6]],
                                        operator.add,
                                        0,
                                        lambda x: x if type(x) == int and x % 2 == 0 else 0),
                                        12)

    def test_stream(self):
        s = Stream('''Some
text


@@ 125,6 137,7 @@ on some lines''')
        self.assertEqual(parse('l33t c0d3r', identifier), 'l33t')
        self.assertEqual(parse('val+34', identifier), 'val')

    @unittest.skip("broken until parsing indentation works")
    def test_parse_block(self):
        r = parse(textwrap.dedent(
                    '''\


@@ 134,10 147,11 @@ on some lines''')
        self.assertEqual(r, ['a', ['b']])

        print('--')
        self.assertEqual(parse('\n  if true:\n   a(2)\n   else:\n   a(1)\n', block, debug=True),
        self.assertEqual(parse('\n  return 1', block, debug=True),
                ['if', 'true', [['call', 'a', [2]]], 'else', [['call', 'a', [1]]]])


    @unittest.skip("broken until parsing indentation works")
    def test_if(self):
        self.assertEqual(parse(textwrap.dedent(
                        '''\


@@ 346,16 360,23 @@ on some lines''')
            jne l
            '''))

    def test_compiler(self):
        t = If(1, 7, 4)
    def test_codegen(self):
        g = []
        code_gen(normalize_stmt(parse(textwrap.dedent(
        code_gen_module(
                [FunctionDef('main', [], Block())],
                g)

    def test_compiler(self):
        main = normalize_stmt(parse(textwrap.dedent(
        '''\
        if 1:
            7
        else:
            4
        '''), if_stmt)), None, g)
        '''), if_stmt))

        g = []
        code_gen_module([FunctionDef('main', [], main)], g)

        postlude = emit('rdi <- rax') +\
                   emit('rax <- 60') +\