~robin_jadoul/binja_poke_vm

be79d3d7bcf3d004425e3dbb198a9b6e8881f87b — Robin Jadoul 3 years ago fb852c8
This is great now, works for the last VM version I have
1 files changed, 174 insertions(+), 29 deletions(-)

M __init__.py
M __init__.py => __init__.py +174 -29
@@ 1,5 1,6 @@
from binaryninja.architecture import Architecture
from binaryninja.types import Type
from binaryninja.callingconvention import CallingConvention
from binaryninja.enums import BranchType
from binaryninja.function import (
        RegisterInfo,
        InstructionInfo,


@@ 9,34 10,94 @@ from binaryninja.function import (
        IntrinsicInput,
        )
from binaryninja.lowlevelil import ILRegister
from binaryninja.types import Type

import struct

def unpack4(x):
    return struct.unpack("<L", x)[0]
    return struct.unpack("<l", x)[0]

def unpack8(x):
    return struct.unpack("<Q", x)[0]
    return struct.unpack("<q", x)[0]

class Instr:
    # r = register
    # i = immediate
    # XY = (register, immediate) addressing pair
    # a = address
    ops = {
            0: ("add", "rrr"),
            1: ("read", "r"),
            2: ("write", "r"),
            3: ("movi", "ri"),
            4: ("xor", "rrr"),
            9: ("exit", "r"),
            0x0: ("add", "rrr"),
            0x1: ("addi", "rri"),
            0x2: ("sub", "rrr"),
            0x3: ("read", "r"),
            0x4: ("write", "r"),
            0x5: ("movi", "ri"),
            0x6: ("xor", "rrr"),
            0x7: ("load", "rXY"),
            0x8: ("loadb", "rXY"),
            0x9: ("store", "XYr"),
            0xa: ("storeb", "XYr"),
            0xb: ("jmp", "a"),
            0xc: ("jr", "r"),
            0xd: ("jeq", "rra"),
            0xe: ("jb", "rra"),
            0xf: ("exit", "r"),
            0x10: ("nop", ""),
            }

    def __init__(self, data):
    def __init__(self, data, addr):
        self.opcode = unpack4(data[:4])
        self.op, self.argtypes = Instr.ops[self.opcode]
        self.args = []
        self.size = 0x1c
        self.addr = addr
        self.args = []
        for i, t in enumerate(self.argtypes):
            s = 4 + 8 * i
            self.args.append((t, unpack8(data[s:s+8])))

        ext, extinstr = self.can_continue(data)
        if extinstr == self.op:
            return
        next = [Instr(data[self.size * i:self.size * (i + 1)], addr + self.size * i) for i in range(1, ext)]
        self.attempt_extend(next)

    def can_be_call(self):
        if self.op == "addi" and self.args == [("r", 10), ("r", 11), ("i", self.size)]:
            return (2, "call")
        return (1, self.op)

    def can_be_pop(self):
        if self.op == "load" and self.args[1:] == [("X", 8), ("Y", 0)]:
            return (2, "pop")
        return (1, self.op)

    def can_be_push(self):
        if self.op == "addi" and self.args == [("r", 8), ("r", 8), ("i", -8)]:
            return (2, "push")
        return (1, self.op)

    def can_continue(self, data):
        navail = len(data) // self.size
        for f in [self.can_be_call, self.can_be_push, self.can_be_pop]:
            needed, newinstr = f()
            if newinstr != self.op and needed <= navail:
                return (needed, newinstr)
        return (1, self.op)

    def attempt_extend(self, next):
        if self.can_be_call()[1] == "call" and next[0].op == "jmp":
            self.op = "call"
            self.args = next[0].args
            self.size += next[0].size
        elif self.can_be_pop()[1] == "pop" and next[0].op == "addi" and next[0].args == [("r", 8), ("r", 8), ("i", 8)]:
            self.op = "pop"
            self.args = [self.args[0]]
            self.size += next[0].size
        elif self.can_be_push()[1] == "push" and next[0].op == "store" and next[0].args[:-1] == [("X", 8), ("Y", 0)]:
            self.op = "push"
            self.args = next[0].args[-1:]
            self.size += next[0].size

    def get_text(self):
        ITT = InstructionTextToken
        ITTType = InstructionTextTokenType


@@ 46,44 107,121 @@ class Instr:
                tokens.append(ITT(ITTType.RegisterToken, f"R{v}"))
            elif t == "i":
                tokens.append(ITT(ITTType.IntegerToken, str(v), v))
            elif t == "a":
                tokens.append(ITT(ITTType.PossibleAddressToken, str(v), v))
            elif t == "X":
                tokens.append(ITT(ITTType.BeginMemoryOperandToken, "["))
                tokens.append(ITT(ITTType.RegisterToken, f"R{v}"))
                continue
            elif t == "Y":
                if v != 0:
                    tokens.append(ITT(ITTType.TextToken, " + "))
                    tokens.append(ITT(ITTType.IntegerToken, str(v), v))
                tokens.append(ITT(ITTType.EndMemoryOperandToken, "]"))
            else:
                assert False
            tokens.append(ITT(ITTType.OperandSeparatorToken, ", "))
        return tokens[:-1]

    def get_info(self):
        info = InstructionInfo()
        info.length = self.size

        op = self.op
        if op == "jmp":
            info.add_branch(BranchType.UnconditionalBranch, self.args[0][1])
        elif op == "jr":
            if self.args[0][1] == 10:
                info.add_branch(BranchType.FunctionReturn)
            else:
                info.add_branch(BranchType.IndirectBranch)
        elif op in ["jeq", "jb"]:
            info.add_branch(BranchType.TrueBranch, self.args[2][1])
            info.add_branch(BranchType.FalseBranch, self.addr + self.size)
        elif op == "call":
            info.add_branch(BranchType.CallDestination, self.args[0][1])
            info.add_branch(BranchType.UnconditionalBranch, self.addr + 0x1c * 2)

        return info

    def get_label(self, il, addr):
        if (l := il.get_label_for_address(arch, addr)) is not None:
            return l
        il.add_label_for_address(arch, addr)
        return self.get_label(il, addr)

    def get_llil(self, il):
        args = []
        for t, v in self.args:
            if t == "r":
                args.append(il.reg(8, f"R{v}"))
            elif t == "i":
            if t in "rX":
                if v == 11:
                    args.append(il.const(8, self.addr + self.size))
                else:
                    args.append(il.reg(8, f"R{v}"))
            elif t in "iaY":
                args.append(il.const(8, v))
            else:
                assert False
        if self.op in ["add", "xor", "movi"]:
            f = {"add": il.add, "xor": il.xor_expr, "movi": lambda _, a: a}[self.op]
        if self.op in ["add", "addi", "sub", "xor", "movi"]:
            f = {
                    "add": il.add,
                    "addi": il.add,
                    "sub": il.sub,
                    "xor": il.xor_expr,
                    "movi": lambda _, a: a
                }[self.op]
            v = f(8, *args[1:])
            return [il.set_reg(8, f"R{self.args[0][1]}", v)]
        elif self.op == "read":
            return [il.intrinsic([ILRegister(Architecture["poke"], Architecture["poke"].get_reg_index(f"R{self.args[0][1]}"))], "read", [])]
            return [il.intrinsic([ILRegister(arch, arch.get_reg_index(f"R{self.args[0][1]}"))], "read", [])]
        elif self.op == "write":
            return [il.intrinsic([], "write", args)]
        elif self.op.startswith("load"):
            dest = f"R{self.args[0][1]}"
            src = il.add(8, args[1], args[2])
            size = 1 if self.op.endswith("b") else 8
            return [il.set_reg(size, dest, il.load(size, src))]
        elif self.op.startswith("store"):
            dest = il.add(8, args[0], args[1])
            src = args[2]
            size = 1 if self.op.endswith("b") else 8
            return [il.store(size, dest, src)]
        elif self.op == "push":
            return [il.push(8, args[0])]
        elif self.op == "pop":
            return [il.set_reg(8, f"R{self.args[0][1]}", il.pop(8))]
        elif self.op == "jmp":
            return [il.jump(args[0])]
        elif self.op == "jr":
            if self.args[0][1] == 10:
                return [il.ret(args[0])]
            return [il.jump(args[0])]
        elif self.op == "jeq":
            cond = il.compare_equal(8, args[0], args[1])
            return [il.if_expr(cond, self.get_label(il, self.args[2][1]), self.get_label(il, self.addr + self.size))]
        elif self.op == "jb":
            cond = il.compare_unsigned_less_than(8, args[0], args[1])
            return [il.if_expr(cond, self.get_label(il, self.args[2][1]), self.get_label(il, self.addr + self.size))]
        elif self.op == "call":
            return [il.set_reg(8, "R10", il.const(8, self.addr + 0x1c)), il.call(args[0])]
        elif self.op == "exit":
            return [il.intrinsic([], "exit", args), il.no_ret()]
        elif self.op == "nop":
            return [il.nop()]
        else:
            return [il.unimplemented()]

class Poke(Architecture):
    name = 'poke'
    instr_alignment = 1
    max_instr_length = 0x1c
    max_instr_length = 0x1c * 2
    address_size = 8
    default_int_size = 8

    regs = {**{f"R{i}": RegisterInfo(f"R{i}", 8) for i in range(7)},
                "R7": RegisterInfo("R7", 8),
    regs = {**{f"R{i}": RegisterInfo(f"R{i}", 8) for i in range(0xb)},
                "R11": RegisterInfo("R11", 8),
                "PC": RegisterInfo("R7", 8),
                "SP": RegisterInfo("SP", 8),
                "SP": RegisterInfo("R8", 8),
            }
    stack_pointer = "SP"



@@ 94,22 232,29 @@ class Poke(Architecture):
                }

    def get_instruction_info(self, data, addr):
        instr = Instr(data[:self.max_instr_length])

        res = InstructionInfo()
        res.length = instr.size
        # No branching yet
        return res
        try:
            instr = Instr(data[:self.max_instr_length], addr)
            return instr.get_info()
        except:
            return None

    def get_instruction_text(self, data, addr):
        instr = Instr(data[:self.max_instr_length])
        instr = Instr(data[:self.max_instr_length], addr)
        return instr.get_text(), instr.size

    def get_instruction_low_level_il(self, data, addr, il):
        instr = Instr(data[:self.max_instr_length])
        instr = Instr(data[:self.max_instr_length], addr)
        for x in instr.get_llil(il):
            il.append(x)
        return instr.size

class CC(CallingConvention):
    callee_saved_regs = ["R0", "R1", "R2", "R3", "R10"]
    int_arg_regs = ["R4", "R5"]
    int_return_reg = "R4"
    high_int_return_reg = "R5"

Poke.register()
arch = Architecture['poke']
arch.register_calling_convention(CC(arch, "poke"))
arch.standalone_platform.default_calling_convention = arch.calling_conventions['poke']