~nch/python-compiler

65f5fe2532efb3c91b750a30b5d37a518528bf4d — nc 2 months ago 83fbefe master
add basic support for function calls
2 files changed, 72 insertions(+), 54 deletions(-)

M compiler.py
M test_compiler.py
M compiler.py => compiler.py +43 -35
@@ 544,8 544,8 @@ def pack8(imm, signed=False):
def pack16(imm):
    return struct.pack('H', imm)

def pack32(imm):
    return struct.pack('<L', imm)
def pack32(imm, signed=False):
    return struct.pack('<l' if signed else '<L', imm)

def pack64(imm):
    return struct.pack('<Q', imm)


@@ 568,6 568,7 @@ def mem_p(reg_p):
    return lambda x: type(x) in (list, tuple) and reg_p(x[0])
def or_p(a, b): return lambda x: a(x) or b(x)
def imm32_p(x): return type(x) == int
def imm16_p(x): return type(x) == int and 0 <= x <= 2**16 - 1
def imm8_p(x): return type(x) == int and 0 <= x <= 255

_modrm_pattern = {'*': 0b00, '*+disp8': 0b01, '*+disp32': 0b10, 'direct': 0b11}


@@ 692,11 693,11 @@ See https://wiki.osdev.org/X86-64_Instruction_Encoding for more information.
regs64 = {r: i for i, r in enumerate('rax rcx rdx rbx rsp rbp rsi rdi r8 r9 r10 r11 r12 r13 r14 r15'.split())}
regs32 = {r: i for i, r in enumerate('eax ecx edx ebx esp ebp esi edi r8d r9d r10d r11d r12d r13d r14d r15d'.split())}

def compute_offset(label):
def compute_offset(label, opcode_size):
    def f(loc):
        if label not in labels:
            raise Exception('undefined label', label)
        return labels[label] - loc - 2
        return labels[label] - loc - opcode_size
    return f

def pass2(instructions: List[Union[bytes, tuple]]) -> bytes:


@@ 709,7 710,7 @@ def pass2(instructions: List[Union[bytes, tuple]]) -> bytes:
        if isinstance(i, tuple):
            bytelen, f = i
            b = f(len(r))
            assert(len(b) == bytelen) # sanity check
            assert len(b) == bytelen, len(b) # sanity check
            r += b
        elif isinstance(i, bytes):
            r += i


@@ 717,6 718,7 @@ def pass2(instructions: List[Union[bytes, tuple]]) -> bytes:
            assert False, f'{type(i)} not handled'
    return r

# TODO: add segments so we can disassemble things...
def write_elf(text_section: bytes) -> bytes:
    '''
    for more info see:


@@ 725,17 727,17 @@ def write_elf(text_section: bytes) -> bytes:
    * `man elf`
    '''

    entry_vaddr = 0x401000
    entry_vaddr = 0x400000
    #                          magic        class  data           version  abi version  padding
    #                                       elf64  little endian  1
    ident = struct.pack('16b', *b'\x7fELF', 2,     1,             1,       0,           *([0] * 8))
    ehsize = 64
    phentsize = 56
    ehsize = 64 # size of elf header
    phentsize = 56 # program header entry size
    fsize = ehsize + phentsize + len(text_section)
    #             ident   type       machine     version    entry
    #                     exec       x86_64      1
    elf_header = [ident, pack16(2), pack16(62), pack32(1), pack64(entry_vaddr + ehsize + phentsize)]
    #              phoff                          shoff
    #              phoff           shoff
    elf_header += [pack64(ehsize), pack64(0)]
    #              flags       ehsize                 phentsize
    #              none        size of this header


@@ 757,11 759,14 @@ def write_elf(text_section: bytes) -> bytes:
ops = [
        ## CONTROL FLOW
        (('ret',), lambda _: b'\xc3'),
        (('ret', imm16_p), lambda _, x: b'\xc2' + pack16(x)),

        # FIXME: calculate distance and emit correct opcode based on whether it's a short or long jump
        (('j', label_p), lambda _, l: (2, lambda x: b'\xeb' + pack8(compute_offset(l)(x), signed=True))),
        (('je', label_p), lambda _, l: (2, lambda x: b'\x74' + pack8(compute_offset(l)(x), signed=True))),
        (('jne', label_p), lambda _, l: (2, lambda x: b'\x75' + pack8(compute_offset(l)(x), signed=True))),
        (('j', label_p), lambda _, l: (2, lambda x: b'\xeb' + pack8(compute_offset(l, 2)(x), signed=True))),
        (('je', label_p), lambda _, l: (2, lambda x: b'\x74' + pack8(compute_offset(l, 2)(x), signed=True))),
        (('jne', label_p), lambda _, l: (2, lambda x: b'\x75' + pack8(compute_offset(l, 2)(x), signed=True))),

        (('call', label_p), lambda _, l: (5, lambda x: b'\xe8' + pack32(compute_offset(l, 5)(x), signed=True))),

        ## COMPARISONS
        (('cmp', reg32_p, or_p(reg32_p, mem_p(reg32_p))), lambda _, r1, x: b'\x39' + modrm(r1, x)),


@@ 774,14 779,14 @@ ops = [
        # so I did it to be consistent, but someone did a bit more of an analysis here:
        # http://0x5a4d.blogspot.com/2009/12/on-moving-register.html

        ((or_p(reg32_p, mem_p(reg32_p)), '<-', reg32_p), lambda x, _, r1: b'\x89' + modrm(r1, x)),
        (('mov', or_p(reg32_p, mem_p(reg32_p)), reg32_p), lambda  _, x, r1: b'\x89' + modrm(r1, x)),
        # 0x67 prefix for 32-bit address override (see
        ((reg32_p, '<-', mem_p(reg32_p)), lambda r1, _, x: b'\x67\x8b' + modrm(r1, x)),
        (('mov', reg32_p, mem_p(reg32_p)), lambda _, r1, x: b'\x67\x8b' + modrm(r1, x)),
        # 64 bit movs have a 0x48 prefix to specify 64-bit registers
        ((or_p(reg64_p, mem_p(reg64_p)), '<-', reg64_p), lambda x, _, r1: b'\x48\x89' + modrm(r1, x)),
        ((reg64_p, '<-', mem_p(reg64_p)), lambda r1, _, x: b'\x48\x8b' + modrm(r1, x)),
        (('mov', or_p(reg64_p, mem_p(reg64_p)), reg64_p), lambda _, x, r1: b'\x48\x89' + modrm(r1, x)),
        (('mov', reg64_p, mem_p(reg64_p)), lambda _, r1, x: b'\x48\x8b' + modrm(r1, x)),

        ((general_purpose_reg64, '<-', imm32_p), lambda r, _, i: pack8(ord(b'\xb8') + regs64[r]) + pack32(i)),
        (('mov', general_purpose_reg64, imm32_p), lambda _, r, i: pack8(ord(b'\xb8') + regs64[r]) + pack32(i)),

        ## ARITHMETIC
        (('add', reg32_p, reg32_p), lambda _, r1, r2: b'\x01' + modrm(r2, r1)),


@@ 792,10 797,13 @@ ops = [
]

def emit(*args):
    if hasattr(emit, 'debug') and emit.debug: # HACK -- remove eventually
        print(args)

    '''
    >>> emit('rax', '<-', 'rcx') == b'\\x48\\x89\\xc8' # binary for mov rcx, rax
    >>> emit('mov', 'rax', 'rcx') == b'\\x48\\x89\\xc8' # binary for mov rcx, rax
    True
    >>> emit('rax <- rcx') == b'\\x48\\x89\\xc8' # cutesy syntax
    >>> emit('mov rax rcx') == b'\\x48\\x89\\xc8' # cutesy syntax
    True
    '''
    def maybe_int(x):


@@ 824,8 832,8 @@ def code_gen_module(fs: List, g):
        code_gen(f.body, [f.params], g)

        if f.name == 'main':
            postlude = emit('rdi <- rax') +\
                       emit('rax <- 60') +\
            postlude = emit('mov rdi rax') +\
                       emit('mov rax 60') +\
                       emit('syscall')
            g.append(postlude)



@@ 839,13 847,15 @@ def code_gen(node, ctx: List[Dict[str, int]], g):

@code_gen.register(int) # type: ignore
def _(i: int, ctx, g): # type: ignore
    g.append(emit('rax', '<-', i))
    g.append(emit('mov', 'rax', i))

def asmlen(xs: List[Union[bytes, tuple]]) -> int:
    return sum(x[0] if isinstance(x, tuple) else len(x) for x in xs)

labels = {}
def emit_label(name, g):
    if hasattr(emit, 'debug') and emit.debug: # HACK -- remove eventually
        print(name + ':')
    labels[name] = asmlen(g)

@code_gen.register(If) # type: ignore


@@ 876,40 886,38 @@ def _(block, ctx, g):
@code_gen.register(Assign) # type: ignore
def _(assign: Assign, ctx, g): # type: ignore
    code_gen(assign.rhs, ctx, g)
    g.append(emit(assign.lhs, '<-', 'rax'))
    g.append(emit('mov', assign.lhs, 'rax'))

call_registers = ['rdi', 'rsi', 'rdx', 'rcx', 'r8', 'r9']

R = TypeVar('R')
def traverse_tree(tree: Any, f: Callable[[R, Any], R], r: R, key: Callable=lambda x: x) -> R:
    r = f(r, key(tree))
    if isinstance(tree, Iterable):
    if isinstance(tree, Iterable) and not isinstance(tree, str):
        for subtree in tree:
            r = traverse_tree(subtree, f, r, key)
    return r

@code_gen.register(Return) # type: ignore
def _(ret, ctx, g):
    emit('rax', '<-', ret.value)
    emit('ret', len(ctx[-1]) * 8)
    g.append(emit('mov', 'rax', ret.value))
    g.append(emit('ret', len(ctx[-1]) * 8))

'''
@code_gen.register(FunctionCall)
def _(f: FunctionCall, ctx, g):
@code_gen.register(FunctionCall) # type: ignore
def _(f, ctx, g):
    # TODO: remove this limit and use stack
    assert len(f.args) < len(g.call_registers), 'too many arguments'
    assert len(f.args) < len(call_registers), 'too many arguments'

    # spill used registers
    for r in g.call_registers[:len(f.args)]: g.push(r)
    for r in call_registers[:len(f.args)]: g.push(r)

    for r, a in zip(g.call_registers, f.args):
    for r, a in zip(call_registers, f.args):
        g.mov(a, r)

    g.call(f.name)
    g.append(emit('call', f.name))

    # restore spilled registers
    for r in g.call_registers[:len(f.args)][::-1]: g.pop(r)
'''
    for r in call_registers[:len(f.args)][::-1]: g.pop(r)

if __name__ == "__main__":
    import doctest

M test_compiler.py => test_compiler.py +29 -19
@@ 304,15 304,17 @@ on some lines''')
            return f.read()

    def test_emit(self):
        # TODO: use property based testing to dynamically generate assembly and compare with nasm output

        # moves
        self.assertEqual(emit('rax <- rcx'), self.nasm_assemble(b'mov rax, rcx'))
        self.assertEqual(emit('rax <-', 8),  self.nasm_assemble(b'mov rax, 8'))
        self.assertEqual(emit('rax <-', ('rcx', 8)), self.nasm_assemble(b'mov rax, [rcx+8]'))
        self.assertEqual(emit('eax <-', ('ecx', 8)), self.nasm_assemble(b'mov eax, [ecx+8]'))
        self.assertEqual(emit('rax <- rcx'), b'\x48\x89\xc8')
        self.assertEqual(emit('eax <- ecx'), b'\x89\xc8')
        self.assertEqual(emit('rcx <- rax'), b'\x48\x89\xc1')
        self.assertEqual(emit('ecx <- eax'), b'\x89\xc1')
        self.assertEqual(emit('mov rax rcx'), self.nasm_assemble(b'mov rax, rcx'))
        self.assertEqual(emit('mov rax', 8),  self.nasm_assemble(b'mov rax, 8'))
        self.assertEqual(emit('mov rax', ('rcx', 8)), self.nasm_assemble(b'mov rax, [rcx+8]'))
        self.assertEqual(emit('mov eax', ('ecx', 8)), self.nasm_assemble(b'mov eax, [ecx+8]'))
        self.assertEqual(emit('mov rax rcx'), b'\x48\x89\xc8')
        self.assertEqual(emit('mov eax ecx'), b'\x89\xc8')
        self.assertEqual(emit('mov rcx rax'), b'\x48\x89\xc1')
        self.assertEqual(emit('mov ecx eax'), b'\x89\xc1')

        # cmps
        self.assertEqual(emit('cmp eax ecx'), self.nasm_assemble(b'cmp ecx, eax'))


@@ 328,7 330,7 @@ on some lines''')
        r1.append(emit('j l'))
        r1.append(emit('j l'))
        emit_label('l', r1)
        r1.append(emit('rax <- 1'))
        r1.append(emit('mov rax 1'))
        r1.append(emit('add eax eax'))
        r1.append(emit('j l'))
        self.assertEqual(pass2(r1), self.nasm_assemble(b'''


@@ 341,7 343,7 @@ on some lines''')
        '''))

        r2 = []
        r2.append(emit('rax <- 1'))
        r2.append(emit('mov rax 1'))
        r2.append(emit('jne l'))
        r2.append(emit('jne l'))
        emit_label('l', r2)


@@ 361,10 363,26 @@ on some lines''')
            '''))

    def test_codegen(self):
        '''
        g = []
        code_gen_module(
                [FunctionDef('main', [], Block())],
                g)
        '''

        g = []
        code_gen_module(
                [FunctionDef('main', [], Block(FunctionCall('f'))),
                 FunctionDef('f', [], Return (2))],
                g)
        self.assertEqual(self.execute_program(pass2(g)), 2)

    def execute_program(self, asm_bytes):
        with open('a.bin', 'wb') as f:
            f.write(write_elf(asm_bytes))
        os.system('chmod +x a.bin')
        r = os.system('./a.bin')
        return r >> 8

    def test_compiler(self):
        main = normalize_stmt(parse(textwrap.dedent(


@@ 377,15 395,7 @@ on some lines''')

        g = []
        code_gen_module([FunctionDef('main', [], main)], g)

        postlude = emit('rdi <- rax') +\
                   emit('rax <- 60') +\
                   emit('syscall')
        with open('a.bin', 'wb') as f:
            f.write(write_elf(pass2(g) + postlude))
        os.system("chmod +x a.bin")
        r = os.system("./a.bin")
        self.assertEqual(7, r >> 8)
        self.assertEqual(7, self.execute_program(pass2(g)))

import doctest
def load_tests(loader, tests, ignore):