~robin_jadoul/binja_poke_vm

03c9d3df0f3adbe72661d10e329bba8752b179e5 — Robin Jadoul 3 years ago be79d3d
Restructure towards a more generic potential library/tool structure, add convert_to_nop, 'normal' disassembly and assembly
2 files changed, 249 insertions(+), 85 deletions(-)

M __init__.py
A disassembly.py
M __init__.py => __init__.py +6 -85
@@ 12,92 12,9 @@ from binaryninja.function import (
from binaryninja.lowlevelil import ILRegister
from binaryninja.types import Type

import struct

def unpack4(x):
    """Decode the 4 bytes in `x` as a little-endian signed 32-bit integer."""
    (value,) = struct.unpack("<l", x)
    return value

def unpack8(x):
    """Decode the 8 bytes in `x` as a little-endian signed 64-bit integer."""
    (value,) = struct.unpack("<q", x)
    return value

class Instr:
    """One decoded Poke VM instruction.

    Every encoded instruction is 0x1c bytes: a 4-byte opcode followed by
    8-byte operand slots.  Recognised two-instruction sequences are fused
    into the pseudo-instructions ``call``, ``push`` and ``pop``.
    """

    # Operand type letters used in the table below:
    #   r  = register
    #   i  = immediate
    #   XY = (register, immediate) addressing pair
    #   a  = address
    ops = {
        0x0: ("add", "rrr"),
        0x1: ("addi", "rri"),
        0x2: ("sub", "rrr"),
        0x3: ("read", "r"),
        0x4: ("write", "r"),
        0x5: ("movi", "ri"),
        0x6: ("xor", "rrr"),
        0x7: ("load", "rXY"),
        0x8: ("loadb", "rXY"),
        0x9: ("store", "XYr"),
        0xa: ("storeb", "XYr"),
        0xb: ("jmp", "a"),
        0xc: ("jr", "r"),
        0xd: ("jeq", "rra"),
        0xe: ("jb", "rra"),
        0xf: ("exit", "r"),
        0x10: ("nop", ""),
    }

    def __init__(self, data, addr):
        """Decode the instruction at the start of `data`, located at `addr`.

        Raises KeyError when the opcode is not in `ops`.
        """
        self.opcode = unpack4(data[:4])
        mnemonic, kinds = Instr.ops[self.opcode]
        self.op = mnemonic
        self.argtypes = kinds
        self.size = 0x1c
        self.addr = addr
        # Operand slot `idx` lives at byte offset 4 + 8 * idx.
        self.args = [
            (kind, unpack8(data[4 + 8 * idx:12 + 8 * idx]))
            for idx, kind in enumerate(kinds)
        ]

        count, fused = self.can_continue(data)
        if fused != self.op:
            # Decode the follow-up instruction(s) needed to confirm the idiom.
            followers = [
                Instr(data[self.size * k:self.size * (k + 1)], addr + self.size * k)
                for k in range(1, count)
            ]
            self.attempt_extend(followers)

    def can_be_call(self):
        """Return (2, "call") if this could start a call idiom, else (1, op)."""
        starts_call = self.op == "addi" and self.args == [("r", 10), ("r", 11), ("i", self.size)]
        return (2, "call") if starts_call else (1, self.op)

    def can_be_pop(self):
        """Return (2, "pop") if this could start a pop idiom, else (1, op)."""
        starts_pop = self.op == "load" and self.args[1:] == [("X", 8), ("Y", 0)]
        return (2, "pop") if starts_pop else (1, self.op)

    def can_be_push(self):
        """Return (2, "push") if this could start a push idiom, else (1, op)."""
        starts_push = self.op == "addi" and self.args == [("r", 8), ("r", 8), ("i", -8)]
        return (2, "push") if starts_push else (1, self.op)

    def can_continue(self, data):
        """Report the first idiom this instruction may begin that fits in `data`.

        Returns (instructions needed, pseudo-op name), or (1, op) when no
        idiom applies or `data` holds too few instructions.
        """
        room = len(data) // self.size
        for probe in (self.can_be_call, self.can_be_push, self.can_be_pop):
            span, name = probe()
            if name == self.op or span > room:
                continue
            return (span, name)
        return (1, self.op)

    def attempt_extend(self, next):
        """Fuse with `next[0]` when together they form call, pop or push.

        On a match `op`, `args` and `size` are rewritten in place; otherwise
        the instruction is left untouched.
        """
        if self.can_be_call()[1] == "call" and next[0].op == "jmp":
            fused, operands = "call", next[0].args
        elif (self.can_be_pop()[1] == "pop" and next[0].op == "addi"
              and next[0].args == [("r", 8), ("r", 8), ("i", 8)]):
            fused, operands = "pop", [self.args[0]]
        elif (self.can_be_push()[1] == "push" and next[0].op == "store"
              and next[0].args[:-1] == [("X", 8), ("Y", 0)]):
            fused, operands = "push", next[0].args[-1:]
        else:
            return
        self.op = fused
        self.args = operands
        self.size += next[0].size
from . import disassembly

class Instr(disassembly.Instr):
    def get_text(self):
        ITT = InstructionTextToken
        ITTType = InstructionTextTokenType


@@ 248,6 165,10 @@ class Poke(Architecture):
            il.append(x)
        return instr.size

    def convert_to_nop(self, data, addr):
        """Return `nop` instructions covering exactly `len(data)` bytes.

        `data` must be a whole number of 0x1c-byte instructions.  Each nop is
        the byte 0x10 followed by 0x1b zero bytes.
        """
        count, leftover = divmod(len(data), 0x1c)
        assert leftover == 0
        single_nop = b"\x10" + b"\0" * 0x1b
        return single_nop * count

class CC(CallingConvention):
    callee_saved_regs = ["R0", "R1", "R2", "R3", "R10"]
    int_arg_regs = ["R4", "R5"]

A disassembly.py => disassembly.py +243 -0
@@ 0,0 1,243 @@


import struct

def unpack4(x):
    """Decode the 4 bytes in `x` as a little-endian signed 32-bit integer."""
    (value,) = struct.unpack("<l", x)
    return value

def pack4(v):
    """Encode `v` as a little-endian signed 32-bit integer (4 bytes)."""
    encoder = struct.Struct("<l")
    return encoder.pack(v)

def unpack8(x):
    """Decode the 8 bytes in `x` as a little-endian signed 64-bit integer."""
    (value,) = struct.unpack("<q", x)
    return value

def pack8(v):
    """Encode `v` as a little-endian signed 64-bit integer (8 bytes)."""
    encoder = struct.Struct("<q")
    return encoder.pack(v)

class Instr:
    """One Poke VM instruction, decoded from bytes or built by `parse`.

    A real instruction is 0x1c bytes: a little-endian 4-byte opcode followed
    by three little-endian 8-byte operand slots.  Three two-instruction
    idioms are fused into the pseudo-instructions ``call``, ``push`` and
    ``pop``; those carry their name (a string) as `opcode` and are 0x38
    bytes long.  Bytes whose opcode is unknown decode to a 1-byte ``???``.
    """

    # r = register
    # i = immediate
    # XY = (register, immediate) addressing pair
    # a = address
    ops = {
            0x0: ("add", "rrr"),
            0x1: ("addi", "rri"),
            0x2: ("sub", "rrr"),
            0x3: ("read", "r"),
            0x4: ("write", "r"),
            0x5: ("movi", "ri"),
            0x6: ("xor", "rrr"),
            0x7: ("load", "rXY"),
            0x8: ("loadb", "rXY"),
            0x9: ("store", "XYr"),
            0xa: ("storeb", "XYr"),
            0xb: ("jmp", "a"),
            0xc: ("jr", "r"),
            0xd: ("jeq", "rra"),
            0xe: ("jb", "rra"),
            0xf: ("exit", "r"),
            0x10: ("nop", ""),

            # fake opcodes
            "call": ("call", "a"),
            "push": ("push", "r"),
            "pop": ("pop", "r")
            }
    # Display aliases for the special registers, and the reverse mapping
    # used when parsing assembly text.
    readable = {"R8": "SP", "R9": "BP", "R10": "RA", "R11": "PC"}
    readableR = {"SP": "R8", "BP": "R9", "RA": "R10", "PC": "R11"}

    def __init__(self, data, addr):
        """Decode the instruction at the start of `data`, located at `addr`.

        `data` may extend past this instruction; the extra bytes are used to
        recognise the call/push/pop idioms.  Raises struct.error when `data`
        is too short to hold the opcode or an operand.
        """
        self.opcode = struct.unpack("<l", data[:4])[0]
        self.addr = addr
        self.args = []

        if self.opcode not in Instr.ops:
            # Undecodable: consume one byte and keep it as an immediate.
            self.op = "???"
            self.argtypes = "i"
            self.size = 1
            self.args.append(("i", data[0]))
            return

        self.op, self.argtypes = Instr.ops[self.opcode]
        self.size = 0x1c
        for i, t in enumerate(self.argtypes):
            s = 4 + 8 * i
            self.args.append((t, struct.unpack("<q", data[s:s + 8])[0]))

        ext, extinstr = self.can_continue(data)
        if extinstr == self.op:
            return
        following = [
            Instr(data[self.size * i:self.size * (i + 1)], addr + self.size * i)
            for i in range(1, ext)
        ]
        self.attempt_extend(following)

    @staticmethod
    def _parse_register(text):
        """Return the register number (0..11) named by `text`, or raise RuntimeError."""
        name = Instr.readableR.get(text.upper(), text.upper())
        number = name[1:]
        # isdecimal() (not isdigit()) so everything that passes is accepted by int()
        if not name.startswith("R") or not number.isdecimal() or not 0 <= int(number) <= 11:
            raise RuntimeError(f"Register {name} is not a proper register")
        return int(number)

    @staticmethod
    def _parse_immediate(text):
        """Return `text` as an integer (base chosen by prefix), or raise RuntimeError."""
        try:
            return int(text, 0)
        except ValueError as err:
            raise RuntimeError(f"Invalid immediate value: {text}") from err

    @classmethod
    def parse(cls, instr, addr):
        """Build an instruction located at `addr` from one line of assembly text.

        Accepted operand syntax: registers ``R0``..``R11`` (or SP/BP/RA/PC),
        immediates in any base `int(x, 0)` understands, memory operands
        ``[reg]`` or ``[reg, offset]``, and hexadecimal addresses optionally
        wrapped in ``<>``.  Matching is case-insensitive.  Operands beyond
        those the opcode needs are ignored.

        Raises RuntimeError for an unknown mnemonic, a missing operand or a
        malformed operand.
        """
        instr = instr.lower().replace("\t", " ").strip()
        while "  " in instr:
            instr = instr.replace("  ", " ")
        for loose, tight in ((", ", ","), ("[ ", "["), (" ]", "]"), ("< ", "<"), (" >", ">")):
            instr = instr.replace(loose, tight)

        # partition() instead of split()[1] so operand-less mnemonics ("nop") work
        op, _, operand_text = instr.partition(" ")
        args = [a.strip() for a in operand_text.split(",")] if operand_text else []

        for c, (opp, ts) in Instr.ops.items():
            if op == opp:
                break
        else:
            raise RuntimeError(f"Unknown operation: {op}")

        # Start from a placeholder decode, then overwrite every field.
        o = cls(b"\xff" * 0x1c, addr)
        o.opcode = c
        o.op = op
        o.addr = addr
        o.args = []
        o.argtypes = ts
        # Pseudo-instructions (string opcodes) expand to two real instructions.
        o.size = 2 * 0x1c if isinstance(c, str) else 0x1c

        remaining = iter(args)
        kinds = iter(ts)
        for t in kinds:
            a = next(remaining, None)
            if a is None:
                raise RuntimeError(f"Not enough arguments for opcode {op}")

            if t == "r":
                o.args.append((t, cls._parse_register(a)))
            elif t == "X":
                if not a.startswith("["):
                    raise RuntimeError("Expected a memory operand")
                a = a[1:]
                closed = a.endswith("]")
                if closed:
                    a = a[:-1]
                o.args.append((t, cls._parse_register(a)))
                if closed:
                    # "[reg]" without an offset: the Y slot is implicitly 0,
                    # so consume its type letter without consuming an operand.
                    next(kinds, None)
                    o.args.append(("Y", 0))
            elif t == "Y":
                if not a.endswith("]"):
                    raise RuntimeError("Unclosed memory operand")
                o.args.append((t, cls._parse_immediate(a[:-1])))
            elif t == "i":
                o.args.append((t, cls._parse_immediate(a)))
            elif t == "a":
                target = a.strip("<>")
                try:
                    value = int(target, 16)
                except ValueError as err:
                    raise RuntimeError(f"Invalid address: {target}") from err
                o.args.append((t, value))
        return o

    def can_be_call(self):
        """Return (2, "call") if this could start a call idiom, else (1, op)."""
        if self.op == "addi" and self.args == [("r", 10), ("r", 11), ("i", self.size)]:
            return (2, "call")
        return (1, self.op)

    def can_be_pop(self):
        """Return (2, "pop") if this could start a pop idiom, else (1, op)."""
        if self.op == "load" and self.args[1:] == [("X", 8), ("Y", 0)]:
            return (2, "pop")
        return (1, self.op)

    def can_be_push(self):
        """Return (2, "push") if this could start a push idiom, else (1, op)."""
        if self.op == "addi" and self.args == [("r", 8), ("r", 8), ("i", -8)]:
            return (2, "push")
        return (1, self.op)

    def can_continue(self, data):
        """Report the first idiom this instruction may begin that fits in `data`.

        Returns (instructions needed, pseudo-op name), or (1, op) when no
        idiom applies or `data` holds too few instructions.
        """
        navail = len(data) // self.size
        for f in [self.can_be_call, self.can_be_push, self.can_be_pop]:
            needed, newinstr = f()
            if newinstr != self.op and needed <= navail:
                return (needed, newinstr)
        return (1, self.op)

    def attempt_extend(self, next):
        """Fuse with `next[0]` when together they form call, pop or push.

        On a match `op`, `opcode`, `args` and `size` are rewritten in place;
        otherwise (or when `next` is empty) the instruction is left untouched.
        """
        if not next:
            return
        follower = next[0]
        if self.can_be_call()[1] == "call" and follower.op == "jmp":
            self.op = self.opcode = "call"
            self.args = follower.args
            self.size += follower.size
        elif (self.can_be_pop()[1] == "pop" and follower.op == "addi"
              and follower.args == [("r", 8), ("r", 8), ("i", 8)]):
            self.op = self.opcode = "pop"
            self.args = [self.args[0]]
            self.size += follower.size
        elif (self.can_be_push()[1] == "push" and follower.op == "store"
              and follower.args[:-1] == [("X", 8), ("Y", 0)]):
            self.op = self.opcode = "push"
            self.args = follower.args[-1:]
            self.size += follower.size

    def assemble(self):
        """Return the encoded bytes of this instruction (`size` bytes long).

        Pseudo-instructions are expanded to their two real instructions and
        each is assembled in turn.  Raises AssertionError for a string
        opcode that is not call/push/pop.
        """
        if self.op == "???":
            # Undecodable byte: emit it unchanged so the output matches `size`.
            return bytes([self.args[0][1]])

        if isinstance(self.opcode, str):
            if self.op == "call":
                instrs = [
                            "addi RA, PC, 0x1c",
                            # parse() reads addresses as hexadecimal
                            f"jmp {self.args[0][1]:x}"
                        ]
            elif self.op == "push":
                instrs = [
                            "addi R8, R8, -8",
                            f"store [R8], R{self.args[0][1]}"
                        ]
            elif self.op == "pop":
                instrs = [
                            f"load R{self.args[0][1]}, [R8]",
                            "addi R8, R8, 8"
                        ]
            else:
                raise AssertionError(f"Cannot assemble pseudo-instruction {self.op!r}")
            chunks = []
            addr = self.addr
            for unparsed in instrs:
                parsed = Instr.parse(unparsed, addr)
                chunks.append(parsed.assemble())
                addr += parsed.size
            return b"".join(chunks)

        # Unused operand slots are encoded as zero.
        operands = [v for _, v in self.args][:3]
        operands += [0] * (3 - len(operands))
        return struct.pack("<lqqq", self.opcode, *operands)

    def __str__(self):
        """Format as ``<addr> <op> <operands>`` with the address in hex."""
        args = []
        for t, v in self.args:
            if t == "i":
                args.append(str(v))
            elif t in "rX":
                name = f"R{v}"
                args.append(Instr.readable.get(name, name))
            elif t == "a":
                args.append(f"<{v:06x}>")
            elif t == "Y":
                # Y always follows its X: wrap the register just emitted.
                if v:
                    args[-1] = f"[{args[-1]}, {v}]"
                else:
                    args[-1] = f"[{args[-1]}]"
        return f"{self.addr:06x} {self.op} {', '.join(args)}"

def disassemble(data, addr=0):
    """Decode `data` into a list of `Instr`, the first one located at `addr`.

    Decoding stops once fewer than 0x1c bytes (one full instruction) remain;
    any shorter tail is ignored.
    """
    instrs = []
    # `>=` rather than `>`: a final chunk of exactly one instruction must
    # still be decoded instead of being silently dropped.
    while len(data) >= 0x1c:
        instr = Instr(data, addr)
        instrs.append(instr)
        addr += instr.size
        data = data[instr.size:]
    return instrs

def display(instrs):
    """Return the string form of each item in `instrs`, one per line."""
    lines = [str(instr) for instr in instrs]
    return "\n".join(lines)

if __name__ == "__main__":
    import sys

    # Script entry point: disassemble the file named by the first argument.
    # The context manager closes the file instead of leaking the handle.
    with open(sys.argv[1], "rb") as program:
        image = program.read()
    print(display(disassemble(image)))