~nch/python-compiler

4f4ed481f2e224fc4d759e22e2217857ce6fd08b — nc 1 year, 7 months ago ab1eb6c
start building an assembler
2 files changed, 183 insertions(+), 47 deletions(-)

M compiler.py
M test_compiler.py
M compiler.py => compiler.py +146 -46
@@ 4,7 4,6 @@

from dataclasses import dataclass
from functools import singledispatch
from collections.abc import Iterable
from operator import itemgetter

class NilToken(str):


@@ 514,85 513,182 @@ def _(node: FunctionCall):
    return FunctionCall(node.name, *new_args), hoisted

### Code gen (x86-64)

# TODO: register allocation -- go through the tree and tag bindings with registers or memory locations
import struct

def pack8(imm):
    return struct.pack('B', imm)
def pack32(imm):
    return struct.pack('<L', imm)

def reg_p(x): return x in regs
def small_reg_p(x): return reg_p(x) and regs[x] < 8
def mem_p(x): return list(x) or tuple(x)
def or_p(a, b): return lambda x: a(x) or b(x)
def imm32_p(x): return type(x) == int
def imm8_p(x): return type(x) == int and 0 <= x <= 255

_modrm_pattern = {'*': 0b00, '*+disp8': 0b01, '*+disp32': 0b10, 'direct': 0b11}
def _pack_modrm(reg_id, rm, mod):
    '''
    Construct a ModR/M byte

#import io
#code_gen_buffer = io.BytesIO()
    bit pattern:
    xx xxx xxx = mod (2 bits) | reg (3 bits) | rm (3 bits)
    '''
    return struct.pack('B', _modrm_pattern[mod] << 6 | reg_id << 3 | rm)

@dataclass
class Register:
    name: str
    num: int
def modrm(reg1, reg2_or_mem):
    '''
    Build a ModR/M byte sequence (used to encode register/memory arguments efficiently for a x86)

    Example 1.

            mov modrm('rax', 'rcx')                  # rcx = rax

        To emit an instruction to move the contents of rax into rcx, we want to encode the ModR/M byte
        using "direct" addressing (i.e. no memory offsets -- just copy one register directly into the other).
        We would encode this into a two byte sequence:

            [byte 1  (mov opcode)] '1011001'
            [byte 2 (ModR/M byte)] '11' (direct mod), '000' (rax), '001' (rcx)

        `[byte 2]` is the "argument" to the mov opcode. `mod` is kind of like a flag that changes the way `mov`
        can work. When set to '11' it uses direct mode.

    Example 2.

            mov modrm('rax', ['rcx'])                # rcx = *rax

        In this situation, we actually want to treat the value in rax as a pointer. Instead of rax's its value
        to rcx, we want to dereference it, and copy the value at the memory location to rcx. This is called
        'indirect' addressing, since we're doing pointer dereference (indirection) to access the value.

            [byte 1  (mov opcode)] '1011001'
            [byte 2 (ModR/M byte)] '00' (indirect mod), '000' (rax), '001' (rcx)

# FIXME: I don't like this class
#   I don't like that it holds both architectural details, as well as a buffer that is being written to...
#   TODO: remove this whole abstraction and just implement the assembly instructions inline, and append to an io buffer
class CodeGenBuffer:
    def __init__(self):
        self.buffer = []
        _reg_enc = 'rax rcx rdx rbx rsp rbp rsi rdi r8 r9 r10 r11 r12 r13 r14 r15'
        for i, r in enumerate(_reg_enc.split()):
            self.__dict__[r] = Register(r, i)
        self.call_registers = [self.rdi, self.rsi, self.rdx, self.rcx, self.r8, self.r9]
        The ModR/M byte is only 1 bit different than the previous example, but performs a very different
        function.

    def push(self, r):
        self.buffer.append(['push', r.name])
    Example 3.
        If we want to emit an instruction sequence to perform the following:

    def mov(self, a, b):
        self.buffer.append(['mov', a, b])
            mov modrm('rax', ['rdx', 4]) rdx = *(rax + 4)

    def call(self, f):
        self.buffer.append(['call', f])
        Which may be more recognizable in the following form:

    def assign(self, a, b):
        self.buffer.append(['assign', a, b])
            rdx = rax[4]

    def branch_if_false(self, label):
        self.buffer.append(['cmp', 'rax'])
        self.buffer.append(['bne', label])
        We want to take the memory address in rax, add 4 to it, and get the value there. We could emit a
        sequence of instructions to essentially perform:

    def branch(self, label):
        self.buffer.append(['ba', label])
            rcx = rax
            rcx += 4
            rdx = *rcx

    def label(self, label):
        self.buffer.append(label)
        However, dereferencing a register plus a known offset is such a common operation that having to emit
        3 instructions every time this came up would lead to inefficiency and code bloat. That's why the ModR/M
        byte has two more mod settings. When it's set to '01', it'll perform a 8-bit indirect dereference,
        and when set to '10' it'll perform a 32-bit indirect dereference. The only reason there are two

        In this case we want to emit the ModR/M byte, followed by the sequence of bytes encoding the offset.
        The function will emit:

            [byte 1  (mov opcode)] '1011001'
            [byte 2 (ModR/M byte)] '01' (indirect + disp8 mod), '000' (rax), '010' (rdx)
            [byte 3       (disp8)] '0000100' (offset)


    For info: https://wiki.osdev.org/X86-64_Instruction_Encoding#ModR.2FM_and_SIB_bytes
    '''
    if type(reg2_or_mem) == list:
        reg2 = reg2_or_mem[0]
        offset = 0 if len(reg2_or_mem) == 1 else reg2_or_mem[1]
        if not offset: # *reg2
            return _pack_modrm(regs[reg1], regs[reg2], 'indirect')
        elif imm8_p(offset): # *(reg2 + offset) when offset is small enough to be encoded in 1 byte
            return _pack_modrm(regs[reg1], regs[reg2], '*+disp8') + struct.pack('B', offset)
        elif imm32_p(offset): # *(reg2 + offset) when offset must be encoded in 4 bytes
            return _pack_modrm(regs[reg1], regs[reg2], '*+disp32') + struct.pack('<L', offset)
        assert False, f'Unknown indirect addressing mode {reg2_or_mem}'
    else:
        return _pack_modrm(regs[reg1], regs[reg2_or_mem], 'direct')

import io
codegen_buf = io.BytesIO()

# see https://wiki.osdev.org/X86-64_Instruction_Encoding#Registers
regs = {r: i for i, r in enumerate('rax rcx rdx rbx rsp rbp rsi rdi r8 r9 r10 r11 r12 r13 r14 r15'.split())}

# a full opcode list can be found here: http://ref.x86asm.net/coder64.html
ops = [
        (('ret',), lambda _: '\xc3'),
        # reg -> reg moves get encoded with 0x89 because this is what NASM does. NASM gets used for testing
        # so I did it to be consistent, but someone did more an analysis at some point:
        # http://0x5a4d.blogspot.com/2009/12/on-moving-register.html
        ((or_p(reg_p, mem_p), '<-', reg_p), lambda x, _, r1: b'\x89' + modrm(r1, x)),
        ((reg_p, '<-', mem_p), lambda r1, _, x: b'\x8b' + modrm(r1, x)),
        ((small_reg_p, '<-', imm32_p), lambda r, _, i: pack8(int('\xc7') + r) + pack32(i)),
]

def emit(*args):
    '''
    >>> emit('rax', '<-', 'rcx') == b'\\x89\\xc8' # binary for mov rcx, rax
    True
    >>> emit('rax <- rcx') == b'\\x89\\xc8' # cutesy syntax
    True
    '''
    args = sum(map(str.split, args), []) # allow cutsey syntax
    for op, encoder_f in ops:
        if len(op) != len(args):
            continue

        for o, v in zip(op, args):
            if (type(o) == str or type(o) == int) and o != v:
                break
            if type(o) == callable and not o(v):
                break
        else:
            return encoder_f(*args)
    assert False, f"unknown encoding for {args}"

@singledispatch
def code_gen(node, ctx, g: CodeGenBuffer):
def code_gen(node, ctx, g):
    assert False, f'unhandled: {node} type {type(node)}'

@code_gen.register(int)
def _(i: int, ctx, g: CodeGenBuffer):
    g.mov(i, 'rax')
def _(i: int, ctx, g):
    g.append(emit('rax', '<-', i))

@code_gen.register(If)
def _(ifstmt: If, ctx, g: CodeGenBuffer):
def _(ifstmt: If, ctx, g):
    false_label = gensym('false')
    end_label = gensym('end')
    g.mov(ifstmt.cond, g.rax)
    g.branch_if_false(false_label)
    code_gen(ifstmt.cond, ctx, g)
    g.append(emit('jne', false_label))

    # true branch:
    code_gen(ifstmt.then, ctx, g)
    g.branch(end_label)
    g.append(emit('j', end_label))

    # false branch:
    g.label(false_label)
    g.append(emit_label(false_label))
    code_gen(ifstmt.otherwise, ctx, g)

    g.label(end_label)
    g.append(emit_label(end_label))

@code_gen.register(Block)
def _(block: Block, ctx, g: CodeGenBuffer):
def _(block: Block, ctx, g):
    for stmt in block: code_gen(stmt, ctx, g)

@code_gen.register(Assign)
def _(assign: Assign, ctx, g: CodeGenBuffer):
    g.assign(assign.lhs, assign.rhs)
def _(assign: Assign, ctx, g):
    code_gen(assign.rhs, ctx, g)
    g.append(emit(assign.lhs, '<-', 'rax'))

'''
@code_gen.register(FunctionCall)
def _(f: FunctionCall, ctx, g: CodeGenBuffer):
def _(f: FunctionCall, ctx, g):
    # TODO: remove this limit and use stack
    assert len(f.args) < len(g.call_registers), 'too many arguments'



@@ 604,6 700,10 @@ def _(f: FunctionCall, ctx, g: CodeGenBuffer):

    g.call(f.name)

    # restore spilled registers
    for r in g.call_registers[:len(f.args)][::-1]: g.pop(r)
'''

if __name__ == "__main__":
    import doctest


M test_compiler.py => test_compiler.py +37 -1
@@ 149,12 149,48 @@ else:
                 ('=', 'x', ('a', 'tmp8'))])
    '''

    def test_pack_modrm(self):
        import compiler
        def bin(x):
            return format(x, '#010b').replace('0b', '')
        self.assertEqual(bin(ord(compiler._pack_modrm(4, 2, '*+disp8'))), '01' '100' '010')
        self.assertEqual(bin(ord(compiler._pack_modrm(1, 1, '*'))),       '00' '001' '001')
        self.assertEqual(bin(ord(compiler._pack_modrm(7, 7, 'direct'))),  '11' '111' '111')

    def nasm_assemble(self, code: bytes):
        import tempfile
        import subprocess
        import binascii

        with tempfile.NamedTemporaryFile(delete=False) as f:
            f.write(b'BITS 64\n')
            f.write(code)

        p = subprocess.Popen(f'nasm -f bin -o {f.name}.out {f.name}'.split(),
                stdout = subprocess.PIPE,
                stderr = subprocess.PIPE)

        stdout, stderr = p.communicate()

        if stderr:
            raise Exception(stdout, stderr)

        with open(f'{f.name}.out', 'rb') as f:
            return f.read()

    def test_emit(self):
        self.assertEqual(emit('rax <- rcx'), self.nasm_assemble(b'mov rax, rcx')[1:])
        self.assertEqual(emit('rax <- rcx'), b'\x89\xc8')
        self.assertEqual(emit('rcx <- rax'), b'\x89\xc1')

    '''
    def test_codegen(self):
        t = If(FunctionCall(FunctionCall('b', ()), 2), Return(1), Return(2))
        g = CodeGenBuffer()
        g = ''
        code_gen(normalize_stmt(t), None, g)
        import pprint
        pprint.pprint(g.buffer)
    '''

import doctest
def load_tests(loader, tests, ignore):