~magnusmorton/delayrepay

a323e722946a1b0e50aef73c2cb8db46cdebe485 — Magnus Morton 6 months ago 01a881e
removed unneeded stuff
1 files changed, 0 insertions(+), 91 deletions(-)

M delayrepay/cuda.py
M delayrepay/cuda.py => delayrepay/cuda.py +0 -91
@@ 1,7 1,5 @@
""" CUDA numpy array with delayed-evaluation semantics

Needs split up

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later


@@ 23,39 21,6 @@ import delayrepay.ir as ir

Shape = Tuple[int, int]

OPS = {
    "matmul": "@",
    "add": "+",
    "multiply": "*",
    "subtract": "-",
    "true_divide": "/",
}

FUNCS = {
    "power": "pow",
    "arctan2": "atan2",
    "absolute": "abs",
    "sin": "sin",
    "cos": "cos",
    "tan": "tan",
    "sqrt": "sqrt",
    "log": "log",
    # HACK
    "negative": "-",
    "exp": "exp",
    "tanh": "tanh",
    "sinh": "sinh",
    "cosh": "cosh",
}

ufunc_lookup = {
    "matmul": cupy.matmul,
    "add": cupy.add,
    "multiply": cupy.multiply,
    "subtract": cupy.subtract,
    "true_divide": cupy.true_divide,
}


class Visitor:
    """Visitor ABC"""


@@ 96,14 61,6 @@ class NumpyVisitor(Visitor):
        return self.visit(tree)


def is_matrix_matrix(left, right):
    return len(left) > 1 and len(right) > 1


def is_matrix_vector(left, right):
    return len(left) > 1 and len(right) == 1


InputDict = Dict[str, "BaseFragment"]




@@ 170,17 127,6 @@ class InputFragment(BaseFragment):
        return f"{self.name}"


# dtype_map = {np.dtype("float32"): "float",
#              np.dtype("float64"): "double",
#              np.dtype("int32"): "int",
#              np.dtype("int64"): "long"}

# dtype_map = {np.dtype("float32"): "f",
#              np.dtype("float64"): "",
#              np.dtype("int32"): "",
#              np.dtype("int64"): ""}


class ScalarFragment(BaseFragment):
    def __init__(self, val: ir.Scalar) -> None:
        super().__init__()


@@ 215,43 161,6 @@ def combine_inputs(*args: InputDict) -> InputDict:
    return ret


class PrettyPrinter(Visitor):
    def visit(self, node):
        if isinstance(node, list):
            return self.list_visit(node)
        print(type(node).__name__)
        self.visit(node.children)


class Fuser(Visitor):
    def __init__(self):
        super().__init__()
        self.seen = {}
        self.splits = []

    def fuse(self, node):

        self.visit(node)
        self.splits.append(node)
        return self.splits

    def visit(self, node) -> Shape:
        if isinstance(node, list):
            return self.list_visit(node)
        child_shapes = self.list_visit(node.children)
        new = []
        for child, shape in zip(node.children, child_shapes):
            if shape != node.shape and shape != (0,):
                new.append(ir.NPRef(child, node.shape))
                self.splits.append(child)
            else:
                new.append(child)

        node.children = new

        return node.shape


class CupyEmitter(Visitor):
    def __init__(self):
        super().__init__()