~nch/python-compiler

e7a4a935d2ac1355c21668b317ffb4a508cf0ca0 — nc 1 year, 7 months ago 95f1dbe
add initial support for labels and jumps to the assembler and some more work on indentation
2 files changed, 79 insertions(+), 15 deletions(-)

M compiler.py
M test_compiler.py
M compiler.py => compiler.py +48 -14
@@ 5,7 5,7 @@
from dataclasses import dataclass
from functools import singledispatch
from operator import itemgetter
from typing import Callable, Tuple
from typing import Callable, Tuple, List, Union

class NilToken(str):
    '''


@@ 406,15 406,13 @@ def indentation(expect='same'):
            actual = 'indent'
        elif len(c) < stream.indent:
            actual = 'dedent'
        else:
            assert False

        if expect != actual:
            return ParseError(stream, f'indentation level: {expect}', f'indentation level: {actual}'), stream
        else:
        if expect == actual:
            new_stream.indent = len(c)
            return EMPTY, new_stream

        return ParseError(stream, f'indentation level: {expect} (from {stream.indent})', f'indentation level: {actual} ({len(c)})'), stream

    return indentationf

# we don't care about space, so we discard it


@@ 441,13 439,14 @@ return_stmt_body: ParserT = seq('return', space, expr)

if_stmt: ParserT = seq('if', space, expr, discard(char(':')),
        lambda x: block(x),
        'else', discard(char(':')),
        indentation('same'), 'else', discard(char(':')),
        lambda x: block(x))

# <stmt> := (<return-stmt-body> | <assign-stmt-body | <if-stmt> | <expr>) <newline>
stmt: ParserT = convert(seq(oneof(return_stmt_body, assign_stmt_body, expr), newline), lambda x: x[0])
# <block> := <newline> (<indentation> <stmt>)+
# <stmt> := (<if-stmt> | <return-stmt-body> | <assign-stmt-body> | <expr>) <newline>
stmt: ParserT = convert(seq(oneof(if_stmt, return_stmt_body, assign_stmt_body, expr), newline), lambda x: x[0])
# <block> := <newline> (<indent> <stmt>)+
block: ParserT = convert(seq(newline, one_or_more(convert(seq(indentation('indent'), stmt), lambda x: x[0]))), lambda x: x[0])
#import tmp; block = tmp.trace_function(block)
# <function> := 'def' <space> <identifier> '(' (<identifier> (',' <space> <identifier>)*) ')' ':' <newline> <block>
function: ParserT = seq('def', space, identifier, char('('), intersperse(identifier, discard(seq(char(','), space))), char(')'), char(':'), block)



@@ 527,13 526,15 @@ def _(function_call):
# TODO: register allocation -- go through the tree and tag bindings with registers or memory locations
import struct

def pack8(imm):
    return struct.pack('B', imm)
def pack8(imm, signed=False):
    '''
    Encode ``imm`` as a single byte.

    With ``signed`` false (the default) the struct format 'B' (unsigned char)
    is used; with ``signed`` true the format 'b' (signed char) is used instead.
    '''
    if signed:
        fmt = 'b'
    else:
        fmt = 'B'
    return struct.pack(fmt, imm)

def pack32(imm):
    '''
    Encode ``imm`` as four bytes using the struct format '<L'
    (little-endian, unsigned).
    '''
    encoded = struct.pack('<L', imm)
    return encoded

def reg64_p(x):
    '''Predicate: true when ``x`` is one of the keys of ``regs64``.'''
    return x in regs64
def reg32_p(x):
    '''Predicate: true when ``x`` is one of the keys of ``regs32``.'''
    return x in regs32
def label_p(x):
    '''
    Predicate: true when ``x`` is a label operand, i.e. exactly a ``str``.

    The check is on the exact type, so instances of ``str`` subclasses are
    rejected.
    '''
    return type(x) is str

def get_reg_p(reg):
    if reg64_p(reg): return reg64_p


@@ 673,10 674,36 @@ See https://wiki.osdev.org/X86-64_Instruction_Encoding for more information.
# Register name -> position of that name in the space-separated list below.
regs64 = {
    name: number
    for number, name in enumerate(
        'rax rcx rdx rbx rsp rbp rsi rdi r8 r9 r10 r11 r12 r13 r14 r15'.split())
}
# Same layout for the 32-bit register names.
regs32 = {
    name: number
    for number, name in enumerate(
        'eax ecx edx ebx esp ebp esi edi r8d r9d r10d r11d r12d r13d r14d r15d'.split())
}

def compute_offset(label):
    '''
    Build a resolver for ``label`` whose lookup is deferred until it is called.

    The returned function takes a location ``loc`` and returns
    ``labels[label] - loc``. The lookup in ``labels`` happens when the
    resolver is called, not when it is created, so the label may be
    registered after this function returns. Raises ``Exception`` if the label
    is still unknown at that point.
    '''
    def resolve(loc):
        if label in labels:
            return labels[label] - loc
        raise Exception('undefined label', label)
    return resolve

def pass2(instructions: List[Union[bytes, Callable, tuple]]) -> bytes:
    '''
    The second pass of the assembler that fills in jump locations and combines the final
    instruction bytestream.

    Each entry of ``instructions`` must be one of:

    * ``bytes``: appended verbatim;
    * a callable: invoked with the number of bytes emitted so far, and the
      bytes it returns are appended;
    * a ``tuple`` whose parts are each either bytes or such a callable,
      processed left to right (so a callable sees the length *including* the
      earlier parts of its own tuple).

    Returns the concatenated ``bytes``. Raises ``TypeError`` for an entry of
    any other kind.
    '''
    # bytearray grows in place; repeated ``bytes +=`` would copy the whole
    # stream for every instruction.
    out = bytearray()
    for instruction in instructions:
        if isinstance(instruction, tuple):
            parts = instruction
        elif isinstance(instruction, bytes) or callable(instruction):
            # The signature allows bare callables, so treat them like a
            # one-part tuple instead of rejecting them.
            parts = (instruction,)
        else:
            # An explicit raise rather than ``assert False``: asserts are
            # stripped under ``python -O``, which would silently drop the entry.
            raise TypeError(
                f'pass2: unsupported instruction of type {type(instruction).__name__}')
        for part in parts:
            out += part(len(out)) if callable(part) else part
    return bytes(out)

# a full opcode list can be found here: http://ref.x86asm.net/coder64.html
ops = [
        ## CONTROL FLOW
        (('ret',), lambda _: '\xc3'),
        (('ret',), lambda _: b'\xc3'),

        # FIXME: calculate distance and emit correct opcode based on whether it's a short or long jump
        (('j', label_p), lambda _, l: (b'\xeb', lambda x: pack8(compute_offset(l)(x) - 1, signed=True))),

        ## COMPARISONS
        (('cmp', reg32_p, or_p(reg32_p, mem_p(reg32_p))), lambda _, r1, x: b'\x39' + modrm(r1, x)),


@@ 695,6 722,9 @@ ops = [
        ((reg64_p, '<-', mem_p(reg64_p)), lambda r1, _, x: b'\x48\x8b' + modrm(r1, x)),

        ((general_purpose_reg64, '<-', imm32_p), lambda r, _, i: pack8(ord(b'\xb8') + regs64[r]) + pack32(i)),

        ## ARITHMETIC
        (('add', reg32_p, reg32_p), lambda _, r1, r2: b'\x01' + modrm(r2, r1))
]

def emit(*args):


@@ 726,6 756,10 @@ def code_gen(node, ctx, g):
def _(i: int, ctx, g): # type: ignore
    g.append(emit('rax', '<-', i))

# Label name -> position recorded by emit_label().
labels = {}
def emit_label(name, g):
    '''
    Register ``name`` in ``labels`` as pointing at the current end of ``g``.

    The stored value is ``len(g)`` at the time of the call; ``g`` itself is
    not modified and nothing is returned.
    '''
    # NOTE(review): len(g) counts entries of g, not encoded bytes -- confirm
    # this matches the unit used by whoever resolves the label later.
    position = len(g)
    labels[name] = position

@code_gen.register(If) # type: ignore
def _(ifstmt, ctx, g):
    false_label = gensym('false')


@@ 738,7 772,7 @@ def _(ifstmt, ctx, g):
    g.append(emit('j', end_label))

    # false branch:
    g.append(emit_label(false_label))
    emit_label(false_label, g)
    code_gen(ifstmt.otherwise, ctx, g)

    g.append(emit_label(end_label))

M test_compiler.py => test_compiler.py +31 -1
@@ 124,6 124,7 @@ on some lines''')
        self.assertEqual(parse('l33t c0d3r', identifier), 'l33t')
        self.assertEqual(parse('val+34', identifier), 'val')

    """
    def test_parse_block(self):
        self.assertEqual(parse(textwrap.dedent(
            '''\


@@ 131,7 132,14 @@ on some lines''')
                b
            c'''), seq(identifier, block)), ['a', ['b']])

        print('--')
        self.assertEqual(parse('\n  if true:\n  a(2)\n  else:\n  a(1)\n', block),
                ['if', 'true', [['call', 'a', [2]]], 'else', [['call', 'a', [1]]]])
    """


    def test_if(self):
        """
        self.assertEqual(parse(textwrap.dedent(
                        '''\
                        if true:


@@ 143,15 151,21 @@ on some lines''')
                        blah_blah()
                        '''), if_stmt),
                        ['if', 'true', [['call', 'print', [2]], ['call', 'print', [3]]], 'else', [['return', 1], ['return', 2]]])

        self.assertEqual(parse(textwrap.dedent(
                        '''\
                        if true:
                            print(1)
                            print(2)
                            if true:
                                print(2)
                            else:
                                return false
                            print(2)
                        else:
                            return 1
                        '''), if_stmt),
                        ['if', 'true', [['call', 'print', [1]], ['call', 'print', [2]]], 'else', [['return', 1]]])
        """

    def test_block(self):
        b1 = Block('a', 'b', 'c')


@@ 283,6 297,22 @@ on some lines''')
        self.assertEqual(emit('cmp eax ecx'), self.nasm_assemble(b'cmp ecx, eax'))
        self.assertEqual(emit('cmp rax rcx'), self.nasm_assemble(b'cmp rcx, rax'))

        # arithmetic
        self.assertEqual(emit('add eax ecx'), self.nasm_assemble(b'add eax, ecx'))

    def test_labels(self):
        '''A backwards jump to a label assembles to the same bytes NASM produces.'''
        program = []
        # The label is registered while the program is still empty, so it
        # refers to position 0.
        emit_label('l', program)
        program.extend([
            emit('rax <-', 1),
            emit('add eax eax'),
            emit('j l'),
        ])
        expected = self.nasm_assemble(b'''
            l:
            mov eax, 1
            add eax, eax
            jmp l
        ''')
        self.assertEqual(pass2(program), expected)

    '''
    def test_codegen(self):
        t = If(FunctionCall(FunctionCall('b', ()), 2), Return(1), Return(2))