~nch/python-compiler

59466df4110716f0423be4359e7a2c663a8a0635 — nc 1 year, 6 months ago 7fd8084
fix offset calculation
2 files changed, 21 insertions(+), 8 deletions(-)

M compiler.py
M test_compiler.py
M compiler.py => compiler.py +11 -8
@@ 679,18 679,21 @@ def compute_offset(label):
    def f(loc):
        if label not in labels:
            raise Exception('undefined label', label)
        return labels[label] - loc
        return labels[label] - loc - 2
    return f

def pass2(instructions: List[Union[bytes, Callable, tuple]]) -> bytes:
def pass2(instructions: List[Union[bytes, tuple]]) -> bytes:
    '''
    The second pass of the assembler that fills in jump locations and combines the final
    instruction bytestream
    '''
    r = b''
    for i in instructions:
        if callable(i):
            r += i(len(r))
        if isinstance(i, tuple):
            bytelen, f = i
            b = f(len(r))
            assert(len(b) == bytelen) # sanity check
            r += b
        elif isinstance(i, bytes):
            r += i
        else:


@@ 703,8 706,8 @@ ops = [
        (('ret',), lambda _: b'\xc3'),

        # FIXME: calculate distance and emit correct opcode based on whether it's a short or long jump
        (('j', label_p), lambda _, l: lambda x: b'\xeb' + pack8(compute_offset(l)(x) - 2, signed=True)),
        (('jne', label_p), lambda _, l: lambda x: b'\x75' + pack8(compute_offset(l)(x) + 2, signed=True)),
        (('j', label_p), lambda _, l: (2, lambda x: b'\xeb' + pack8(compute_offset(l)(x), signed=True))),
        (('jne', label_p), lambda _, l: (2, lambda x: b'\x75' + pack8(compute_offset(l)(x), signed=True))),

        ## COMPARISONS
        (('cmp', reg32_p, or_p(reg32_p, mem_p(reg32_p))), lambda _, r1, x: b'\x39' + modrm(r1, x)),


@@ 766,7 769,7 @@ def _(i: int, ctx, g): # type: ignore

labels = {}
def emit_label(name, g):
    labels[name] = len(g)
    labels[name] = sum(x[0] if isinstance(x, tuple) else len(x) for x in g)

@code_gen.register(If) # type: ignore
def _(ifstmt, ctx, g):


@@ 783,7 786,7 @@ def _(ifstmt, ctx, g):
    emit_label(false_label, g)
    code_gen(ifstmt.otherwise, ctx, g)

    g.append(emit_label(end_label))
    emit_label(end_label, g)

@code_gen.register(Block) # type: ignore
def _(block: Block, ctx, g): # type: ignore

M test_compiler.py => test_compiler.py +10 -0
@@ 304,11 304,15 @@ on some lines''')

    def test_labels(self):
        r1 = []
        r1.append(emit('j l'))
        r1.append(emit('j l'))
        emit_label('l', r1)
        r1.append(emit('rax <- 1'))
        r1.append(emit('add eax eax'))
        r1.append(emit('j l'))
        self.assertEqual(pass2(r1), self.nasm_assemble(b'''
            jmp l
            jmp l
            l:
            mov eax, 1
            add eax, eax


@@ 317,16 321,22 @@ on some lines''')

        r2 = []
        r2.append(emit('rax <- 1'))
        r2.append(emit('jne l'))
        r2.append(emit('jne l'))
        emit_label('l', r2)
        r2.append(emit('add eax eax'))
        r2.append(emit('cmp rax 2000'))
        r2.append(emit('jne l'))
        r2.append(emit('jne l'))
        self.assertEqual(pass2(r2), self.nasm_assemble(b'''
            mov eax, 1
            jne l
            jne l
            l:
            add eax, eax
            cmp rax, 2000
            jne l
            jne l
            '''))

    '''