~nch/python-compiler

80aad7744fe5c63828b96d53ab8c7a46b1df6d1b — nc 1 year, 6 months ago 8433cb3
more cutesy syntax and support for jne
2 files changed, 36 insertions(+), 12 deletions(-)

M compiler.py
M test_compiler.py
M compiler.py => compiler.py +14 -6
@@ 6,6 6,7 @@ from dataclasses import dataclass
from functools import singledispatch
from operator import itemgetter
from typing import Callable, Tuple, List, Union
import itertools

class NilToken(str):
    '''


@@ 688,13 689,12 @@ def pass2(instructions: List[Union[bytes, Callable, tuple]]) -> bytes:
    '''
    r = b''
    for i in instructions:
        if isinstance(i, tuple):
            for x in i:
                r += x(len(r)) if callable(x) else x
        if callable(i):
            r += i(len(r))
        elif isinstance(i, bytes):
            r += i
        else:
            assert False
            assert False, f'{type(i)} not handled'
    return r

# a full opcode list can be found here: http://ref.x86asm.net/coder64.html


@@ 703,11 703,14 @@ ops = [
        (('ret',), lambda _: b'\xc3'),

        # FIXME: calculate distance and emit correct opcode based on whether it's a short or long jump
        (('j', label_p), lambda _, l: (b'\xeb', lambda x: pack8(compute_offset(l)(x) - 1, signed=True))),
        (('j', label_p), lambda _, l: lambda x: b'\xeb' + pack8(compute_offset(l)(x) - 2, signed=True)),
        (('jne', label_p), lambda _, l: lambda x: b'\x75' + pack8(compute_offset(l)(x) + 2, signed=True)),

        ## COMPARISONS
        (('cmp', reg32_p, or_p(reg32_p, mem_p(reg32_p))), lambda _, r1, x: b'\x39' + modrm(r1, x)),
        (('cmp', reg64_p, or_p(reg64_p, mem_p(reg64_p))), lambda _, r1, x: b'\x48\x39' + modrm(r1, x)),
        (('cmp', 'eax', imm32_p), lambda _1, _2, x: b'\x3d' + pack32(x)),
        (('cmp', 'rax', imm32_p), lambda _1, _2, x: b'\x48\x3d' + pack32(x)),

        ## MOVS
        # reg -> reg moves get encoded with 0x89 because this is what NASM does. NASM gets used for testing


@@ 734,7 737,12 @@ def emit(*args):
    >>> emit('rax <- rcx') == b'\\x48\\x89\\xc8' # cutesy syntax
    True
    '''
    args = sum(map(lambda x: str.split(x) if type(x) == str else [x], args), []) # allow cutsey syntax
    def maybe_int(x):
        '''
        Return int(x) when x parses as an integer, otherwise return x unchanged.

        Lets numeric tokens written inside a string, e.g. the '3000' in
        'cmp rax 3000', be passed on as ints, while tokens such as 'rax'
        or '<-' stay as they are.
        '''
        try:
            return int(x)
        except (TypeError, ValueError):
            # Not an integer literal: keep the original token. Only the errors
            # int() raises for bad input are caught, so KeyboardInterrupt and
            # SystemExit are no longer swallowed as they were by a bare except.
            return x
    args = list(itertools.chain(*map(lambda x: map(maybe_int, str.split(x)) if type(x) == str else [x], args))) # allow cutsey syntax
    for op, encoder_f in ops:
        if len(op) != len(args):
            continue

M test_compiler.py => test_compiler.py +22 -6
@@ 296,23 296,39 @@ on some lines''')
        # cmps
        self.assertEqual(emit('cmp eax ecx'), self.nasm_assemble(b'cmp ecx, eax'))
        self.assertEqual(emit('cmp rax rcx'), self.nasm_assemble(b'cmp rcx, rax'))
        self.assertEqual(emit('cmp eax', 3000), self.nasm_assemble(b'cmp eax, 3000'))
        self.assertEqual(emit('cmp rax', 3000), self.nasm_assemble(b'cmp rax, 3000'))

        # arithmetic
        self.assertEqual(emit('add eax ecx'), self.nasm_assemble(b'add eax, ecx'))

    def test_labels(self):
        r = []
        emit_label('l', r)
        r.append(emit('rax <-', 1))
        r.append(emit('add eax eax'))
        r.append(emit('j l'))
        self.assertEqual(pass2(r), self.nasm_assemble(b'''
        r1 = []
        emit_label('l', r1)
        r1.append(emit('rax <- 1'))
        r1.append(emit('add eax eax'))
        r1.append(emit('j l'))
        self.assertEqual(pass2(r1), self.nasm_assemble(b'''
            l:
            mov eax, 1
            add eax, eax
            jmp l
        '''))

        r2 = []
        r2.append(emit('rax <- 1'))
        emit_label('l', r2)
        r2.append(emit('add eax eax'))
        r2.append(emit('cmp rax 2000'))
        r2.append(emit('jne l'))
        self.assertEqual(pass2(r2), self.nasm_assemble(b'''
            mov eax, 1
            l:
            add eax, eax
            cmp rax, 2000
            jne l
            '''))

    '''
    def test_codegen(self):
        t = If(FunctionCall(FunctionCall('b', ()), 2), Return(1), Return(2))