~magnusmorton/delayrepay

ee1c9c9fa22dfe92d21a70dab735f059d3cf4d6c — Magnus Morton 5 months ago ca1b02c
extracted visitor into own module

added visitor.py
4 files changed, 54 insertions(+), 93 deletions(-)

M delayrepay/cuda.py
M delayrepay/delayarray.py
M delayrepay/fft.py
A delayrepay/visitor.py
M delayrepay/cuda.py => delayrepay/cuda.py +3 -41
@@ 14,54 14,16 @@ this program. If not, see <http://www.gnu.org/licenses/>.
"""
# Copyright (C) 2020 by Univeristy of Edinburgh

from typing import Any, List, Dict, Tuple, Union
from typing import List, Dict, Tuple
import cupy  # type: ignore

from .visitor import Visitor

np = cupy
fallback = cupy

Shape = Tuple[int, int]


class Visitor:
    """Visitor ABC"""

    def visit(self, node) -> Any:
        """Visit a node."""
        if isinstance(node, list):
            visitor = self.list_visit
        else:
            method = "visit_" + node.__class__.__name__
            visitor = getattr(self, method, self.default_visit)
        return visitor(node)

    def list_visit(self, lst, **kwargs):
        return [self.visit(node) for node in lst]

    def default_visit(self, node):
        return node


class NumpyVisitor(Visitor):
    """Visits Numpy Expression"""

    def __init__(self):
        self.visits = 0

    def visit(self, node):
        """Visit a node."""
        self.visits += 1
        return super(NumpyVisitor, self).visit(node)

    def visit_BinaryExpression(self, node):
        return node

    def walk(self, tree):
        """ top-level walk of tree"""
        self.visits = 0
        return self.visit(tree)


InputDict = Dict[str, "BaseFragment"]



M delayrepay/delayarray.py => delayrepay/delayarray.py +0 -47
@@ 395,45 395,6 @@ class Scalar(NumpyEx):
        return hash(self.val)


class Visitor:
    """Visitor ABC"""

    def visit(self, node) -> Any:
        """Visit a node."""
        if isinstance(node, list):
            visitor = self.list_visit
        else:
            method = "visit_" + node.__class__.__name__
            visitor = getattr(self, method, self.default_visit)
        return visitor(node)

    def list_visit(self, lst, **kwargs):
        return [self.visit(node) for node in lst]

    def default_visit(self, node):
        return node


class NumpyVisitor(Visitor):
    """Visits Numpy Expression"""

    def __init__(self):
        self.visits = 0

    def visit(self, node):
        """Visit a node."""
        self.visits += 1
        return super(NumpyVisitor, self).visit(node)

    def visit_BinaryExpression(self, node):
        return node

    def walk(self, tree):
        """ top-level walk of tree"""
        self.visits = 0
        return self.visit(tree)


def is_matrix_matrix(left, right):
    return len(left) > 1 and len(right) > 1



@@ 463,14 424,6 @@ def arg_to_numpy_ex(arg: Any) -> NumpyEx:
        raise NotImplementedError


class PrettyPrinter(Visitor):
    def visit(self, node):
        if isinstance(node, list):
            return self.list_visit(node)
        print(type(node).__name__)
        self.visit(node.children)


HANDLED_FUNCTIONS = {}



M delayrepay/fft.py => delayrepay/fft.py +3 -5
@@ 16,12 16,10 @@ this program. If not, see <http://www.gnu.org/licenses/>.


import cupy.fft
from .cuarray import DelayArray
from .delayarray import DelayArray

def fft(self, *args, **kwargs):

    print(args)
def fft(self, *args, **kwargs):
    nargs = [arg.__array__() if isinstance(arg, DelayArray) else arg
            for arg in args]
    print(nargs)
             for arg in args]
    return cupy.fft.fft(*nargs, **kwargs)

A delayrepay/visitor.py => delayrepay/visitor.py +48 -0
@@ 0,0 1,48 @@
from typing import Any


class Visitor:
    """Visitor ABC"""

    def visit(self, node) -> Any:
        """Visit a node."""
        if isinstance(node, list):
            visitor = self.list_visit
        else:
            method = "visit_" + node.__class__.__name__
            visitor = getattr(self, method, self.default_visit)
        return visitor(node)

    def list_visit(self, lst, **kwargs):
        return [self.visit(node) for node in lst]

    def default_visit(self, node):
        return node


class NumpyVisitor(Visitor):
    """Visits Numpy Expression"""

    def __init__(self):
        self.visits = 0

    def visit(self, node):
        """Visit a node."""
        self.visits += 1
        return super(NumpyVisitor, self).visit(node)

    def visit_BinaryExpression(self, node):
        return node

    def walk(self, tree):
        """ top-level walk of tree"""
        self.visits = 0
        return self.visit(tree)


class PrettyPrinter(Visitor):
    def visit(self, node):
        if isinstance(node, list):
            return self.list_visit(node)
        print(type(node).__name__)
        self.visit(node.children)