~nch/glue

6503844b76e7bf0587c58a1c96c985498f44b3aa — nc 1 year, 3 months ago 06e401b
a few more dataflow attempts
3 files changed, 318 insertions(+), 0 deletions(-)

A dataflow1.py
A dataflow2.py
A test_dataflow1.py
A dataflow1.py => dataflow1.py +128 -0
@@ 0,0 1,128 @@
from typing import TypeVar, Generic, Optional, Dict, Any, Iterable, Tuple, Callable

from functools import reduce
import operator

# XXX this code is total crap. Fix it at some point.

T = TypeVar('T')

class Stream(Generic[T]):
    def __init__(self, *deps): # this is a *BAD* idea -- deps tracking should be automatic
        self.value = None
        self.first = None
        self.deps = deps

    def next(self) -> T:
        ...

    def step(self):
        self.value = self.next()
        if self.first is None: # XXX is subtly wrong if None is a valid value
            self.first = self.value
        return self.value

    def fby(self, expr):
        return FbyStream(self, expr)

class FbyStream(Stream):
    def __init__(self, first, rest):
        self._first = first
        self._rest = rest
        self._is_first = True
        super().__init__(self._rest)

    def next(self): # FIXME just... yuck
        if self._is_first:
            self._is_first = False
            return self._first.value
        else:
            return self._rest.value

class IterStream(Stream):
    def __init__(self, i: Iterable):
        self.iter = iter(i)
        super().__init__()

    def next(self):
        return next(self.iter)

class BinOps:
    def __add__(self, x):
        return Expr(operator.add, self, x)
    def __mul__(self, x):
        return Expr(operator.mul, self, x)

class Int(Stream[int], BinOps):
    def __init__(self, i):
        self.i = i
        super().__init__()

    def __str__(self):
        return f'Int({self.i})'

    def next(self):
        return self.i

class Expr(Stream):
    def __init__(self, op, *args):
        self.op = op
        self.args = args
        super().__init__(*self.args)

    def next(self):
        # TODO: Hmm... determine if this should be a.value or a.next
        vals = [a.value for a in self.args]
        if any(isinstance(x, _Var) for x in vals):
            return self # not reducible yet
        else:
            return reduce(self.op, vals)

    def __str__(self):
        s = ' '.join(map(str, self.args))
        return f'({self.op} {s})'

Env = Dict[str, Expr]

def step(env, v):
    # TODO: register all nodes and topological sort then don't do the
    # semi-lazy graph compute thing that's happening here
    processed = set()
    def step_inner(x):
        if x in processed:
            return

        for d in x.deps:
            step_inner(d)

        x.step()
        processed.add(x)

    for v in env.values():
        step_inner(v)

    return env[v]


class _Var(Stream, BinOps):
    def __init__(self, env, name):
        self.name = name
        self.env = env
        super().__init__()

    def __str__(self):
        return f'Var("{self.name}")'

    def bind(self, x):
        self.env[self.name] = x
        self.value = x.value

    def compute_next(self):
        if self.name in self.env:
            return self.env[self.name].next()
        else:
            return self # is this even correct? I think it is...

def create_env() -> Tuple[Env, Callable]: # -> env, Var class
    bindings: Env = {}
    return bindings, lambda varname: _Var(bindings, varname)

A dataflow2.py => dataflow2.py +112 -0
@@ 0,0 1,112 @@
import operator

PRINT_STYLE = 'lisp'

def node_initializer(initf):
    def f(self, *args, **kwargs):
        self.deps = []

        if not hasattr(node_initializer, 'all_nodes'):
            node_initializer.all_nodes = set()

        node_initializer.all_nodes.add(self)

        for a in args:
            if isinstance(a, Val):
                self.deps.append(a)

        return initf(self, *args, **kwargs)
    return f

def step_everything():
    stepped = set()
    def step(x):
        if x in stepped:
            return
        for y in x.deps:
            step(y)
        x.value = x.next()
        stepped.add(x)

    for node in node_initializer.all_nodes:
        step(node)

class Val: # just used for type checking
    pass

class BinOps:
    def __add__(self, x):
        return Expr(operator.add, self, x)
    def __mul__(self, x):
        return Expr(operator.mul, self, x)
    def __sub__(self, x):
        return Expr(operator.sub, self, x)

class Expr(Val, BinOps):
    @node_initializer
    def __init__(self, op, *args):
        self.op = op
        self.args = args

    def __str__(self):
        if PRINT_STYLE == 'py':
            s = ', '.join(map(str, self.args))
            return f'{self.op.__name__}({s})'
        elif PRINT_STYLE == 'lisp':
            s = ' '.join(map(str, self.args))
            return f'({self.op.__name__} {s})'

class Int(Val, BinOps):
    @node_initializer
    def __init__(self, i):
        self.i = i

    def __str__(self):
        if PRINT_STYLE == 'py':
            return f'Int({self.i})'
        elif PRINT_STYLE == 'lisp':
            return str(self.i)

class Var(Val, BinOps):
    @node_initializer
    def __init__(self, name):
        self.name = name

    def __str__(self):
        if PRINT_STYLE == 'py':
            return f"Var('{self.name}')"
        elif PRINT_STYLE == 'lisp':
            return self.name


'''
def lift(x):
    if isinstance(x, Val):
        return x # already lifted
    if isinstance(x, int):
        return Int(x)
    assert False, type(x)

def eval_expr(tree, bindings):
    def chase(x):
        if isinstance(x, Var) and x.name in bindings:
            return chase(bindings[x.name])
        else:
            return x

    if isinstance(tree, Expr):
        args = [eval_expr(a, bindings) for a in tree.args]
        if any(isinstance(a, Var) for a in args): # can't evaluate yet
            return Expr(tree.op, *[lift(a) for a in args])
        else:
            return tree.op(*args)
    if isinstance(tree, Int):
        return tree.i
    if isinstance(tree, Var):
        v = chase(tree)
        if isinstance(v, Expr):
            return eval_expr(v, bindings)
        else:
            return v
    assert False, f"shouldn't get here: {type(tree)}"
'''

A test_dataflow1.py => test_dataflow1.py +78 -0
@@ 0,0 1,78 @@
import unittest
from dataflow1 import *

class TestFRP(unittest.TestCase):
    def test_int(self):
        i = Int(1)
        self.assertEqual(i.step(), 1)
        self.assertEqual(i.step(), 1)

    def test_var(self):
        env, Var = create_env()
        v = Var('v')
        self.assertEqual(v.step(), v)
        v.bind(Int(1))
        self.assertEqual(v.step(), 1)

    def test_expr(self):
        env, Var = create_env()
        e = Var('e')
        e.bind(Int(1) + Int(1))
        self.assertEqual(step(env, 'e'), 2)
        self.assertEqual(step(env, 'e'), 2)
        self.assertEqual(step(env, 'e'), 2)

    def test_expr2(self):
        env, Var = create_env()
        n = Var('n')
        n.bind(Int(1))

        e = n + Int(1)
        self.assertEqual(e.value, 2)
        self.assertEqual(e.next(), 2)

    def test_fby(self):
        env, Var = create_env()
        n = Var('n')
        n.bind(Int(1).fby(Int(2)))
        self.assertEqual(n.value, 1)
        self.assertEqual(n.next(), 2)

        n.bind(Int(1).fby(Int(2).fby(Int(3))))
        self.assertEqual(n.value, 1)
        self.assertEqual(n.next(), 2)
        self.assertEqual(n.next(), 3)

    def test_basic_accessors(self):
        env, Var = create_env()
        i = IterStream([1, 2])
        self.assertEqual(i.step(), 1)
        self.assertEqual(i.first, 1)
        self.assertEqual(i.value, 1)
        self.assertEqual(i.next(), 2)

    def test_recursive(self):
        env, Var = create_env()
        n = Var('n')
        n.bind(Int(1).fby(n + Int(1)))

        self.assertEqual(n.value, 1)
        self.assertEqual(n.next(), 2)
        self.assertEqual(n.next(), 3)
        self.assertEqual(n.next(), 4)

    def test_factorial(self):
        env, Var = create_env()
        n = Var('n')
        fac = Var('fac')
        n.bind(Int(0).fby(n + Int(1)))
        fac.bind(Int(1).fby(fac * (n + Int(1))))

        self.assertEqual(fac.value, 1)
        self.assertEqual(fac.next(), 1)
        self.assertEqual(fac.next(), 2)
        self.assertEqual(fac.next(), 6)
        self.assertEqual(fac.next(), 24)

if __name__ == '__main__':
    unittest.main()