~nch/python-compiler

939533957d953a2a6b53c012dfef7e740c08977c — nc 1 year, 11 months ago c91dde4
clean up nodeclass() and add tests
2 files changed, 53 insertions(+), 13 deletions(-)

M compiler.py
M test_compiler.py
M compiler.py => compiler.py +13 -13
@@ 353,26 353,26 @@ class Block(list):

# took inspiration from namedtuple:
# https://github.com/python/cpython/blob/58ccd201fa74287ca9293c03136fcf1e19800ef9/Lib/collections/__init__.py#L290
def nodeclass(name, fields, default_values=[]):
    if type(fields) == str:
        fields = fields.replace(',', ' ').split()
def nodeclass(name, fields, hole_values=[]):
    fields = fields.replace(',', ' ').split() if type(fields) == str else fields
    class_namespace = {}
    for i, f in enumerate(fields):
        if f == '_': assert i < len(default_values), f"default value for {i} is not passed!"
        if f.startswith('*'): assert i == len(fields) - 1, "splat arg must be last field"
        if f == '_':
            assert i < len(hole_values), f"default value for {i} is not passed!"
            continue
        elif f.startswith('*'):
            assert i == len(fields) - 1, "splat arg must be last field"
            class_namespace[f.lstrip('*')] = property(lambda self: self[i:], doc=f'alias for elements [{i}:]')
        else:
            class_namespace[f] = property(itemgetter(i), doc=f'alias for element at {i}')

    def __new__(cls, *args):
        for i, f in enumerate(fields):
            if f == '_': args = args[:i] + (default_values[i],) + args[i:]
            if f == '_': args = args[:i] + (hole_values[i],) + args[i:]
        return tuple.__new__(cls, args)

    class_namespace = {'__new__': __new__}
    class_namespace['__new__'] = __new__

    for i, f in enumerate(fields):
        if f == '_': continue
        if f.startswith('*'):
            class_namespace[f.lstrip('*')] = property(lambda self: self[i:], doc=f'alias for elements [{i}:]')
        else:
            class_namespace[f] = property(itemgetter(i), doc=f'alias for element at {i}')
    return type(name, (tuple,), class_namespace)

FunctionCall = nodeclass('FunctionCall', 'name *args')

M test_compiler.py => test_compiler.py +40 -0
@@ 67,6 67,46 @@ else:
        b2 = Block('a', Block('b'), 'c')
        self.assertEqual(b2, ['a', 'b', 'c'])

    def test_nodeclass(self):
        A = nodeclass('A', 'a,b,c')
        a = A(1,2,3)
        self.assertEqual(a.a, 1)
        self.assertEqual(a.b, 2)
        self.assertEqual(a.c, 3)
        self.assertEqual(a[0], 1)
        self.assertEqual(a[1], 2)
        self.assertEqual(a[2], 3)

        B = nodeclass('B', 'a _ c', [None, 5])
        b = B(1, 3)
        self.assertEqual(b.a, 1)
        self.assertEqual(b.c, 3)
        self.assertEqual(b[0], 1)
        self.assertEqual(b[1], 5)
        self.assertEqual(b[2], 3)

        C = nodeclass('C', 'a b c *z')
        c = C(1,2,3)
        self.assertEqual(c.a, 1)
        self.assertEqual(c.b, 2)
        self.assertEqual(c.c, 3)
        self.assertEqual(c.z, ())
        self.assertEqual(c[0], 1)
        self.assertEqual(c[1], 2)
        self.assertEqual(c[2], 3)

        D = nodeclass('D', 'a _ c *z', [None, 44])
        d = D(1, 3, 4, 5, 6)
        self.assertEqual(d.a, 1)
        self.assertEqual(d.c, 3)
        self.assertEqual(d.z, (4,5,6))
        self.assertEqual(d[0], 1)
        self.assertEqual(d[1], 44)
        self.assertEqual(d[2], 3)
        self.assertEqual(d[3], 4)
        self.assertEqual(d[4], 5)
        self.assertEqual(d[5], 6)

    def test_normalize(self):
        tree1 = FunctionCall('a', 'b', 'c')
        self.assertEqual(normalize_expr(tree1)[0], ('a', 'b', 'c'))