~magnusmorton/delayrepay

7ebd752f04ad2e2b8f136ceb4200a8a4dd9178f8 — Magnus Morton 6 months ago eaacb14
using split code now
5 files changed, 352 insertions(+), 354 deletions(-)

M delayrepay/__init__.py
M delayrepay/cuda.py
M delayrepay/delayarray.py
M delayrepay/ir.py
M delayrepay/random.py
M delayrepay/__init__.py => delayrepay/__init__.py +1 -1
@@ 17,7 17,7 @@ this program. If not, see <http://www.gnu.org/licenses/>.
# Copyright (C) 2020 by University of Edinburgh

#from numpy import *
from .cuarray import *
from .delayarray import *
import delayrepay.random

import cupy

M delayrepay/cuda.py => delayrepay/cuda.py +12 -11
@@ 17,7 17,8 @@ this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Any, List, Dict, Tuple, Union
import cupy  # type: ignore

import delayrepay.ir as ir
np = cupy
fallback = cupy

Shape = Tuple[int, int]



@@ 115,7 116,7 @@ class Fragment(BaseFragment):


class InputFragment(BaseFragment):
    def __init__(self, name: str, arr: Union[ir.NPArray, ir.NPRef]) -> None:
    def __init__(self, name, arr) -> None:
        super().__init__()
        self.name = name
        self._inputs = {self.name: arr}


@@ 128,7 129,7 @@ class InputFragment(BaseFragment):


class ScalarFragment(BaseFragment):
    def __init__(self, val: ir.Scalar) -> None:
    def __init__(self, val) -> None:
        super().__init__()
        self.val = val.val
        self.dtype = val.dtype


@@ 179,7 180,7 @@ class CupyEmitter(Visitor):
            self.count += 1
        return visited

    def visit_BinaryNumpyEx(self, node: ir.BinaryNumpyEx) -> BaseFragment:
    def visit_BinaryNumpyEx(self, node) -> BaseFragment:
        op = node.to_op()
        left = self.visit(node.children[0])
        right = self.visit(node.children[1])


@@ 188,13 189,13 @@ class CupyEmitter(Visitor):
        stmts = left.stmts + right.stmts + [stmt]
        return Fragment(name, stmts, combine_inputs(left.inputs, right.inputs))

    def visit_UnaryFuncEx(self, node: ir.UnaryFuncEx) -> BaseFragment:
    def visit_UnaryFuncEx(self, node) -> BaseFragment:
        inner = self.visit(node.children[0])
        name = f"unfunc{self.count}"
        stmts = inner.stmts + [f"T {name} = {node.to_op()}({inner.ref()})"]
        return Fragment(name, stmts, inner.inputs)

    def visit_BinaryFuncEx(self, node: ir.BinaryFuncEx) -> BaseFragment:
    def visit_BinaryFuncEx(self, node) -> BaseFragment:
        op = node.to_op()
        left = self.visit(node.children[0])
        right = self.visit(node.children[1])


@@ 203,16 204,16 @@ class CupyEmitter(Visitor):
        stmts = left.stmts + right.stmts + [stmt]
        return Fragment(name, stmts, combine_inputs(left.inputs, right.inputs))

    def visit_NPArray(self, node: ir.NPArray) -> BaseFragment:
    def visit_NPArray(self, node) -> BaseFragment:
        return InputFragment(f"arr{self.count}", node)

    def visit_NPRef(self, node: ir.NPRef) -> BaseFragment:
    def visit_NPRef(self, node) -> BaseFragment:
        return InputFragment(f"ref{self.count}", node)

    def visit_Scalar(self, node: ir.Scalar) -> BaseFragment:
    def visit_Scalar(self, node) -> BaseFragment:
        return ScalarFragment(node)

    def visit_ReduceEx(self, node: ir.ReduceEx) -> BaseFragment:
    def visit_ReduceEx(self, node) -> BaseFragment:
        inner = self.visit(node.children[0])
        name = node.name
        op = node.to_op()


@@ 220,7 221,7 @@ class CupyEmitter(Visitor):
        return NotImplemented


def run_gpu(ex: ir.NumpyEx) -> cupy.array:
def run_gpu(ex) -> cupy.array:
    visitor = CupyEmitter()
    kerns = [visitor.visit(ex)]
    for kern in kerns:

M delayrepay/delayarray.py => delayrepay/delayarray.py +338 -12
@@ 17,10 17,10 @@ this program. If not, see <http://www.gnu.org/licenses/>.
# Copyright (C) 2020 by Univeristy of Edinburgh

from numbers import Number
from typing import Any, List, Tuple
import numpy as np  # type: ignore
import numpy.lib.mixins  # type: ignore
import delayrepay.cuda as backend
import delayrepay.ir as ir


def cast(func):


@@ 29,7 29,7 @@ def cast(func):
    def wrapper(*args, **kwargs):
        arr = func(*args, **kwargs)
        if not isinstance(arr, DelayArray):
            arr = ir.NPArray(arr)
            arr = NPArray(arr)
        return arr

    return wrapper


@@ 57,25 57,25 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
            if not isinstance(left, Number) and not isinstance(right, Number):
                if left.shape != right.shape:
                    if left.shape != (0,) and right.shape != (0,):
                        return backend.ufunc_lookup[ufunc.__name__](
                        return ufunc_lookup[ufunc.__name__](
                            left.__array__(), right.__array__()
                        )
        if ufunc.__name__ == "matmul":
            return self._dot(inputs, kwargs)
        # cls = func_to_numpy_ex(ufunc)
        args = [ir.arg_to_numpy_ex(arg) for arg in inputs]
        return ir.create_ex(ufunc, args)
        args = [arg_to_numpy_ex(arg) for arg in inputs]
        return create_ex(ufunc, args)

    def _dot_mv(self, args, kwargs):
        return ir.MVEx(args[0], args[1])
        return MVEx(args[0], args[1])

    def _dot_mm(self, args, kwargs):
        return ir.MMEx(args[0], args[1])
        return MMEx(args[0], args[1])

    @cast
    def _dot(self, args, kwargs):
        # scalar result dot
        args = [ir.arg_to_numpy_ex(arg) for arg in args]
        args = [arg_to_numpy_ex(arg) for arg in args]
        # if is_matrix_matrix(args[0].shape, args[1].shape):
        #     return self._dot_mm(args, kwargs)
        # if is_matrix_vector(args[0].shape, args[1].shape):


@@ 99,7 99,7 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
        return less(self, other)

    def dot(self, other, out=None):
        return backend.fallback.dot(self, other)
        return self._dot(other, out)

    def get(self):
        return self.__array__().get()


@@ 108,7 108,7 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
        self.__array__()

    def reshape(self, *args, **kwargs):
        return ir.NPArray(self.__array__().reshape(*args, **kwargs))
        return NPArray(self.__array__().reshape(*args, **kwargs))

    def __setitem__(self, key, item):
        arr = self.__array__()


@@ 145,6 145,332 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
        return repeat(self, *args, **kwargs)


Shape = Tuple[int, int]


OPS = {
    "matmul": "@",
    "add": "+",
    "multiply": "*",
    "subtract": "-",
    "true_divide": "/",
}


FUNCS = {
    "power": "pow",
    "arctan2": "atan2",
    "absolute": "abs",
    "sin": "sin",
    "cos": "cos",
    "tan": "tan",
    "sqrt": "sqrt",
    "log": "log",
    # HACK
    "negative": "-",
    "exp": "exp",
    "tanh": "tanh",
    "sinh": "sinh",
    "cosh": "cosh",
}

ufunc_lookup = {
    "matmul": backend.np.matmul,
    "add": backend.np.add,
    "multiply": backend.np.cupy.multiply,
    "subtract": backend.np.cupy.subtract,
    "true_divide": backend.np.true_divide,
}


def calc_shape(left, right, op=None):
    if left == (0,):
        return right
    if right == (0,):
        return left
    if op.__name__ in OPS:
        return left
    if op.__name__ == "dot":
        # for now
        if len(left) > 1 and len(right) > 1:
            return (left[0], right[1])
        elif len(left) > 1:
            return (left[0],)
        else:
            return (0,)
    else:
        return left


class Memoiser(type):
    """Metaclass implementing caching"""

    def __new__(meta, *args, **kwargs):
        cls = super(Memoiser, meta).__new__(meta, *args, **kwargs)
        meta._cache = {}
        return cls

    def __call__(cls, *args):
        if type(args[0]).__name__ == "ndarray":
            key = id(args[0])
        else:
            key = hash(args)
        if key not in cls._cache:
            Memoiser._cache[key] = super(Memoiser, cls).__call__(*args)
        return cls._cache[key]


def reset():
    # hacks
    Memoiser._cache.clear()


class NumpyEx(DelayArray, metaclass=Memoiser):
    children: List["NumpyEx"]
    """Numpy expression"""

    def __init__(self, children: List["NumpyEx"] = []):
        super().__init__()
        self.dtype = None
        self.children = children

    def __hash__(self):
        """
        Should work because of the Memoizer
        """
        return id(self)


class Funcable:
    def to_op(self):
        return OPS[self.func.__name__]


class ReduceEx(NumpyEx, Funcable):
    def __init__(self, func, arg):
        super().__init__(children=[arg])
        self.func = func
        self.shape = (0,)

    # func: np.ufunc
    # arg: NumpyEx


class UnaryFuncEx(NumpyEx, Funcable):
    def __init__(self, func, arg):
        super().__init__(children=[arg])
        self.func = func
        self.shape = arg.shape
        self.dtype = arg.dtype

    def to_op(self):
        return FUNCS[self.func.__name__]


class BinaryFuncEx(NumpyEx):
    def __init__(self, func, left, right):
        super().__init__(children=[left, right])
        self.func = func
        self.shape = calc_shape(left.shape, right.shape, func)
        self.dtype = calc_type(left, right)

    def to_op(self):
        return FUNCS[self.func.__name__]


def pow_ex(func, left, right):
    if not isinstance(right.val, int):
        return BinaryFuncEx(func, left, right)
    ex = left
    for i in range(right.val - 1):
        # will give odd expression tree, but OK
        ex = BinaryNumpyEx(np.multiply, ex, left)

    return ex


def create_ex(func, args):
    if func.__name__ in OPS:
        return BinaryNumpyEx(func, *args)
    if func.__name__ == "square":
        return BinaryNumpyEx(np.multiply, args[0], args[0])
    if len(args) == 1:
        return UnaryFuncEx(func, *args)
    if func.__name__ == "power":
        return pow_ex(func, *args)
    return BinaryFuncEx(func, *args)


class BinaryNumpyEx(NumpyEx, Funcable):
    """Binary numpy expression"""

    def __init__(self, func, left, right):
        super().__init__(children=[left, right])
        self.func = func
        self.shape = calc_shape(left.shape, right.shape, func)
        self.dtype = calc_type(left, right)


class MMEx(NumpyEx, Funcable):
    # arg1: NumpyEx
    # arg2: NumpyEx
    def __init__(self, arg1, arg2):
        super().__init__()
        self.arg1 = arg1
        self.arg2 = arg2
        self.shape = calc_shape(arg1.shape, arg2.shape, np.dot)


class MVEx(NumpyEx, Funcable):
    # arg1: NumpyEx
    # arg2: NumpyEx
    def __init__(self, arg1, arg2):
        super().__init__()
        self.arg1 = arg1
        self.arg2 = arg2
        self.shape = calc_shape(arg1.shape, arg2.shape, np.dot)


class DotEx(NumpyEx, Funcable):
    def __init__(self, left, right):
        super().__init__()
        self.arg1 = left
        self.arg2 = right
        self.shape = calc_shape(left.shape, right.shape, np.dot)
        self._inshape = left.shape


class NPArray(NumpyEx):
    """ndarray"""

    def __init__(self, array):
        super().__init__()
        self.array = array
        self.shape = array.shape
        self.dtype = array.dtype

    def __hash__(self):
        return id(self.array)

    def __eq__(self, other):
        try:
            return self.array is other.array
        except AttributeError:
            return False

    def astype(self, *args, **kwargs):
        old = self.array
        cast_arr = self.array.astype(*args, **kwargs)
        del NPArray._cache[id(old)]
        NPArray._cache[id(cast_arr)] = self
        self.array = cast_arr
        self.dtype = cast_arr.dtype
        return self


class NPRef(NumpyEx):
    """Only for when breaking dependency chains for fusion"""

    def __init__(self, node: NumpyEx, shape: Shape):
        super().__init__()
        self.ref = node
        self.children = []
        self.shape = shape

    @property
    def array(self):
        return self.ref.array


class Scalar(NumpyEx):
    """a scalar"""

    # val: Number
    def __init__(self, val):
        super().__init__()
        self.val = val
        self.shape = (0,)

    def __hash__(self):
        return hash(self.val)


class Visitor:
    """Visitor ABC"""

    def visit(self, node) -> Any:
        """Visit a node."""
        if isinstance(node, list):
            visitor = self.list_visit
        else:
            method = "visit_" + node.__class__.__name__
            visitor = getattr(self, method, self.default_visit)
        return visitor(node)

    def list_visit(self, lst, **kwargs):
        return [self.visit(node) for node in lst]

    def default_visit(self, node):
        return node


class NumpyVisitor(Visitor):
    """Visits Numpy Expression"""

    def __init__(self):
        self.visits = 0

    def visit(self, node):
        """Visit a node."""
        self.visits += 1
        return super(NumpyVisitor, self).visit(node)

    def visit_BinaryExpression(self, node):
        return node

    def walk(self, tree):
        """ top-level walk of tree"""
        self.visits = 0
        return self.visit(tree)


def is_matrix_matrix(left, right):
    return len(left) > 1 and len(right) > 1


def is_matrix_vector(left, right):
    return len(left) > 1 and len(right) == 1


def calc_type(node1: NumpyEx, node2: NumpyEx) -> np.dtype:
    if node1.dtype is not None:
        node2.dtype = node1.dtype
        return node1.dtype
    node1.dtype = node2.dtype
    return node2.dtype


def arg_to_numpy_ex(arg: Any) -> NumpyEx:
    import cupy
    if isinstance(arg, DelayArray):
        return arg
    elif isinstance(arg, Number):
        return Scalar(arg)
    elif isinstance(arg, cupy.core.core.ndarray) or isinstance(arg, np.ndarray):
        return NPArray(arg)
    else:
        print(type(arg))
        raise NotImplementedError


class PrettyPrinter(Visitor):
    def visit(self, node):
        if isinstance(node, list):
            return self.list_visit(node)
        print(type(node).__name__)
        self.visit(node.children)


HANDLED_FUNCTIONS = {}




@@ 158,10 484,10 @@ def implements(np_function):

@implements(np.diag)
def diag(arr, k=0):
    if isinstance(arr.ex, ir.NPArray):
    if isinstance(arr.ex, NPArray):
        arr._ndarray = np.ascontiguousarray(np.diag(arr._ndarray, k))
        assert arr._ndarray.flags["C_CONTIGUOUS"]
        arr.ex = ir.NPArray(arr._ndarray)
        arr.ex = NPArray(arr._ndarray)
        return arr
    else:
        return NotImplemented

M delayrepay/ir.py => delayrepay/ir.py +0 -329
@@ 14,332 14,3 @@ this program. If not, see <http://www.gnu.org/licenses/>.
"""
# Copyright (C) 2020 by Univeristy of Edinburgh

from numbers import Number
from typing import Tuple, List, Any
import cupy  # type: ignore
import numpy as np  # type: ignore
import numpy.lib.mixins  # type: ignore
from .delayarray import DelayArray


Shape = Tuple[int, int]

OPS = {
    "matmul": "@",
    "add": "+",
    "multiply": "*",
    "subtract": "-",
    "true_divide": "/",
}

FUNCS = {
    "power": "pow",
    "arctan2": "atan2",
    "absolute": "abs",
    "sin": "sin",
    "cos": "cos",
    "tan": "tan",
    "sqrt": "sqrt",
    "log": "log",
    # HACK
    "negative": "-",
    "exp": "exp",
    "tanh": "tanh",
    "sinh": "sinh",
    "cosh": "cosh",
}

ufunc_lookup = {
    "matmul": cupy.matmul,
    "add": cupy.add,
    "multiply": cupy.multiply,
    "subtract": cupy.subtract,
    "true_divide": cupy.true_divide,
}


def calc_shape(left, right, op=None):
    if left == (0,):
        return right
    if right == (0,):
        return left
    if op.__name__ in OPS:
        return left
    if op.__name__ == "dot":
        # for now
        if len(left) > 1 and len(right) > 1:
            return (left[0], right[1])
        elif len(left) > 1:
            return (left[0],)
        else:
            return (0,)
    else:
        return left


class Memoiser(type):
    """Metaclass implementing caching"""

    def __new__(meta, *args, **kwargs):
        cls = super(Memoiser, meta).__new__(meta, *args, **kwargs)
        meta._cache = {}
        return cls

    def __call__(cls, *args):
        if type(args[0]).__name__ == "ndarray":
            key = id(args[0])
        else:
            key = hash(args)
        if key not in cls._cache:
            Memoiser._cache[key] = super(Memoiser, cls).__call__(*args)
        return cls._cache[key]


def reset():
    # hacks
    Memoiser._cache.clear()


class NumpyEx(DelayArray, metaclass=Memoiser):
    children: List["NumpyEx"]
    """Numpy expression"""

    def __init__(self, children: List["NumpyEx"] = []):
        super().__init__()
        self.dtype = None
        self.children = children

    def __hash__(self):
        """
        Should work because of the Memoizer
        """
        return id(self)


class Funcable:
    def to_op(self):
        return OPS[self.func.__name__]


class ReduceEx(NumpyEx, Funcable):
    def __init__(self, func, arg):
        super().__init__(children=[arg])
        self.func = func
        self.shape = (0,)

    # func: np.ufunc
    # arg: NumpyEx


class UnaryFuncEx(NumpyEx, Funcable):
    def __init__(self, func, arg):
        super().__init__(children=[arg])
        self.func = func
        self.shape = arg.shape
        self.dtype = arg.dtype

    def to_op(self):
        return FUNCS[self.func.__name__]


class BinaryFuncEx(NumpyEx):
    def __init__(self, func, left, right):
        super().__init__(children=[left, right])
        self.func = func
        self.shape = calc_shape(left.shape, right.shape, func)
        self.dtype = calc_type(left, right)

    def to_op(self):
        return FUNCS[self.func.__name__]


def pow_ex(func, left, right):
    if not isinstance(right.val, int):
        return BinaryFuncEx(func, left, right)
    ex = left
    for i in range(right.val - 1):
        # will give odd expression tree, but OK
        ex = BinaryNumpyEx(np.multiply, ex, left)

    return ex


def create_ex(func, args):
    if func.__name__ in OPS:
        return BinaryNumpyEx(func, *args)
    if func.__name__ == "square":
        return BinaryNumpyEx(np.multiply, args[0], args[0])
    if len(args) == 1:
        return UnaryFuncEx(func, *args)
    if func.__name__ == "power":
        return pow_ex(func, *args)
    return BinaryFuncEx(func, *args)


class BinaryNumpyEx(NumpyEx, Funcable):
    """Binary numpy expression"""

    def __init__(self, func, left, right):
        super().__init__(children=[left, right])
        self.func = func
        self.shape = calc_shape(left.shape, right.shape, func)
        self.dtype = calc_type(left, right)


class MMEx(NumpyEx, Funcable):
    # arg1: NumpyEx
    # arg2: NumpyEx
    def __init__(self, arg1, arg2):
        super().__init__()
        self.arg1 = arg1
        self.arg2 = arg2
        self.shape = calc_shape(arg1.shape, arg2.shape, np.dot)


class MVEx(NumpyEx, Funcable):
    # arg1: NumpyEx
    # arg2: NumpyEx
    def __init__(self, arg1, arg2):
        super().__init__()
        self.arg1 = arg1
        self.arg2 = arg2
        self.shape = calc_shape(arg1.shape, arg2.shape, np.dot)


class DotEx(NumpyEx, Funcable):
    def __init__(self, left, right):
        super().__init__()
        self.arg1 = left
        self.arg2 = right
        self.shape = calc_shape(left.shape, right.shape, np.dot)
        self._inshape = left.shape


class NPArray(NumpyEx):
    """ndarray"""

    def __init__(self, array):
        super().__init__()
        self.array = array
        self.shape = array.shape
        self.dtype = array.dtype

    def __hash__(self):
        return id(self.array)

    def __eq__(self, other):
        try:
            return self.array is other.array
        except AttributeError:
            return False

    def astype(self, *args, **kwargs):
        old = self.array
        cast_arr = self.array.astype(*args, **kwargs)
        del NPArray._cache[id(old)]
        NPArray._cache[id(cast_arr)] = self
        self.array = cast_arr
        self.dtype = cast_arr.dtype
        return self


class NPRef(NumpyEx):
    """Only for when breaking dependency chains for fusion"""

    def __init__(self, node: NumpyEx, shape: Shape):
        super().__init__()
        self.ref = node
        self.children = []
        self.shape = shape

    @property
    def array(self):
        return self.ref.array


class Scalar(NumpyEx):
    """a scalar"""

    # val: Number
    def __init__(self, val):
        super().__init__()
        self.val = val
        self.shape = (0,)

    def __hash__(self):
        return hash(self.val)


class Visitor:
    """Visitor ABC"""

    def visit(self, node) -> Any:
        """Visit a node."""
        if isinstance(node, list):
            visitor = self.list_visit
        else:
            method = "visit_" + node.__class__.__name__
            visitor = getattr(self, method, self.default_visit)
        return visitor(node)

    def list_visit(self, lst, **kwargs):
        return [self.visit(node) for node in lst]

    def default_visit(self, node):
        return node


class NumpyVisitor(Visitor):
    """Visits Numpy Expression"""

    def __init__(self):
        self.visits = 0

    def visit(self, node):
        """Visit a node."""
        self.visits += 1
        return super(NumpyVisitor, self).visit(node)

    def visit_BinaryExpression(self, node):
        return node

    def walk(self, tree):
        """ top-level walk of tree"""
        self.visits = 0
        return self.visit(tree)


def is_matrix_matrix(left, right):
    return len(left) > 1 and len(right) > 1


def is_matrix_vector(left, right):
    return len(left) > 1 and len(right) == 1


def calc_type(node1: NumpyEx, node2: NumpyEx) -> np.dtype:
    if node1.dtype is not None:
        node2.dtype = node1.dtype
        return node1.dtype
    node1.dtype = node2.dtype
    return node2.dtype


def arg_to_numpy_ex(arg: Any) -> NumpyEx:
    if isinstance(arg, DelayArray):
        return arg
    elif isinstance(arg, Number):
        return Scalar(arg)
    elif isinstance(arg, cupy.core.core.ndarray) or isinstance(arg, np.ndarray):
        return NPArray(arg)
    else:
        print(type(arg))
        raise NotImplementedError


class PrettyPrinter(Visitor):
    def visit(self, node):
        if isinstance(node, list):
            return self.list_visit(node)
        print(type(node).__name__)
        self.visit(node.children)

M delayrepay/random.py => delayrepay/random.py +1 -1
@@ 16,7 16,7 @@ this program. If not, see <http://www.gnu.org/licenses/>.

import cupy.random
#from numpy.random import *
from .cuarray import cast
from .delayarray import cast

rand = cast(cupy.random.rand)
randn = cast(cupy.random.randn)