~martijnbraam/bare-py

ade421e394e5a7a18b0471acb46743d0790deabc — Martijn Braam 8 months ago b529569
Made dump command print more details
2 files changed, 148 insertions(+), 12 deletions(-)

M bare/__main__.py
M bare/dump.py
M bare/__main__.py => bare/__main__.py +3 -1
@@ 80,7 80,9 @@ def codegen(schema, output, indent, skip=None):

    if indent != '\t':
        result = result.replace('\t', indent)
    output.write(result)
    if output is not None:
        output.write(result)
    return result


def main():

M bare/dump.py => bare/dump.py +145 -11
@@ 1,9 1,148 @@
import argparse
import tempfile
import importlib.machinery
import base64
import sys
import imp
import shutil
import textwrap
from enum import EnumMeta
from itertools import zip_longest as zipl

from bare import _unpack_primitive
from bare.__main__ import codegen
from bare.bare_ast import BareType, UnionType, BarePrimitive, TypeKind, StructType, MapType, NamedType, ArrayType, \
    OptionalType


class Line:
    def __init__(self, data, annotation, decoded, indent):
        self.data = data
        self.annotation = ('|' * indent) + annotation
        self.decoded = decoded
        self.indent = indent
        self.hex = ''
        for i, b in enumerate(data):
            self.hex += '{:02X} '.format(b)
            if i % 8 == 7:
                self.hex += '\n'

    def __str__(self):
        size = shutil.get_terminal_size((120, 20))
        hexwidth = 8 * 3
        annotatewidth = 24
        decodedwidth = size.columns - hexwidth - annotatewidth - 4

        hexlines = self.hex.splitlines()
        annotatelines = [self.annotation]
        decodedlines = textwrap.wrap(self.decoded, decodedwidth)

        maxlines = max(len(hexlines), len(annotatelines), len(decodedlines))
        if len(annotatelines) < maxlines:
            annotatelines += ['|' * self.indent] * (maxlines - len(annotatelines))

        result = ''
        for hex, ann, dec in zipl(hexlines, annotatelines, decodedlines, fillvalue=''):
            result += ann.ljust(annotatewidth) + '  ' + hex.ljust(hexwidth) + '  ' + dec + "\n"
        return result


def import_schema(schema):
    code = codegen(schema, None, '\t')
    schema_module = imp.new_module('schema')
    exec(code, schema_module.__dict__)
    sys.modules['schema'] = schema_module
    return schema_module


def dump(data, type, module):
    node = type._ast
    label = None
    offset = 0
    nodelist = []
    indent = 0
    while True:
        if hasattr(node, '_ast') and isinstance(node._ast, StructType):
            yield Line(b'', 'struct {}'.format(label), '', indent)
            end_offset = offset

            newnodes = []
            for fieldname in node._ast.fields:
                newnodes.append((node._ast.fields[fieldname], fieldname))
            nodelist = newnodes + [(None, None)] + nodelist
            indent += 1
        elif hasattr(node, '_ast') and isinstance(node._ast, BarePrimitive):
            nodelist = [(node._ast, label)] + nodelist
        elif isinstance(node, UnionType):
            tag, end_offset = _unpack_primitive(BarePrimitive(TypeKind.UINT), data, offset)
            for type in node.types:
                if type.value == tag:
                    break
            else:
                raise ValueError("Cannot find type for tag {}".format(tag))
            yield Line(data[offset:end_offset], 'Union', 'tag = {} ({})'.format(tag, type.type.name), indent)
            nodelist = [(getattr(module, type.type.name), type.type.name), (None, None)] + nodelist
            indent += 1
        elif isinstance(node, BarePrimitive):
            value, end_offset = _unpack_primitive(node, data, offset)
            if node.length is None:
                dec = '{} = {}'.format(node.type.name, value)
            else:
                dec = '{}<{}> = {}'.format(node.type.name, node.length, value)
            yield Line(data[offset:end_offset], label, dec, indent)
        elif isinstance(node, NamedType):
            referenced = getattr(module, node.name)
            if isinstance(referenced, EnumMeta):
                value, end_offset = _unpack_primitive(BarePrimitive(TypeKind.UINT), data, offset)
                enum = referenced(value)
                yield Line(data[offset:end_offset], label, str(enum), indent)
            else:
                newnode = (referenced, label)
                nodelist = [newnode] + nodelist
                end_offset = offset
        elif isinstance(node, ArrayType):
            if node.length is None:
                length, end_offset = _unpack_primitive(BarePrimitive(TypeKind.UINT), data, offset)
                dec = '[]{}, length = {}'.format(node.subtype.name, length)
            else:
                end_offset = offset
                length = node.length
                dec = '[{}]{}'.format(length, node.subtype.type.name)
            yield Line(data[offset:end_offset], label, dec, indent)
            if length > 0:
                newnodes = []
                for i in range(0, length):
                    newnodes.append((node.subtype, str(i)))
                nodelist = newnodes + [(None, None)] + nodelist
                indent += 1
        elif isinstance(node, OptionalType):
            exists, end_offset = _unpack_primitive(BarePrimitive(TypeKind.UINT), data, offset)
            yield Line(data[offset:end_offset], label, 'optional, value = {}'.format(exists != 0), indent)
            if exists:
                nodelist = [(node.subtype, 'value'), (None, None)] + nodelist
                indent += 1
        elif isinstance(node, MapType):
            length, end_offset = _unpack_primitive(BarePrimitive(TypeKind.UINT), data, offset)
            dec = 'map[{}]{}, length = {}'.format(node.keytype.type.name, node.valuetype.type.name, length)
            yield Line(data[offset:end_offset], label, dec, indent)
            if length > 0:
                newnodes = []
                for i in range(0, length):
                    newnodes.append((node.keytype, 'key {}'.format(i)))
                    newnodes.append((node.valuetype, 'value {}'.format(i)))
                nodelist = newnodes + [(None, None)] + nodelist
                indent += 1

        elif node is None:
            indent -= 1
        else:
            break

        if len(nodelist) == 0:
            return

        offset = end_offset
        node = nodelist[0][0]
        label = nodelist[0][1]
        nodelist = nodelist[1:]


def main():


@@ 20,16 159,11 @@ def main():
        with open(args.message, 'rb') as handle:
            message = handle.read()

    with tempfile.NamedTemporaryFile(suffix='.py', mode='w') as output:
        codegen(args.schema.read(), output, '    ')
        output.flush()
        schema = importlib.machinery.SourceFileLoader('schema', output.name).load_module()

        type = getattr(schema, args.type)
        result = type.unpack(message)
    schema = import_schema(args.schema.read())
    type = getattr(schema, args.type)

    print(result.__class__.__name__)
    print(vars(result))
    for line in dump(message, type, schema):
        print(line, end='')


if __name__ == '__main__':