01a881e4d047f3614f73293fababebeb4a302b0d — Magnus Morton 6 months ago f28d504
removed unneeded stuff. fixed refs
1 files changed, 15 insertions(+), 562 deletions(-)

M delayrepay/cuda.py
M delayrepay/cuda.py => delayrepay/cuda.py +15 -562
@@ 16,11 16,10 @@ this program. If not, see <http://www.gnu.org/licenses/>.
# Copyright (C) 2020 by Univeristy of Edinburgh

from numbers import Number
from typing import Any, List, Dict, Tuple, Optional, Union, Set
from typing import Any, List, Dict, Tuple, Union
import cupy  # type: ignore
import numpy as np  # type: ignore
import numpy.lib.mixins  # type: ignore

import delayrepay.ir as ir

Shape = Tuple[int, int]

@@ 58,338 57,6 @@ ufunc_lookup = {

def cast(func):
    """cast to Delay array decorator"""

    def wrapper(*args, **kwargs):
        arr = func(*args, **kwargs)
        if not isinstance(arr, DelayArray):
            arr = NPArray(arr)
        return arr

    return wrapper

class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
    def __init__(self, *args, **kwargs):
        self._memo = None

    def __repr__(self):
        return str(self.__array__())

    def __array__(self):
        # return NumpyFunction(self.ex)()
            return self.array
        except AttributeError:
            self.array = run_gpu(self)
            return self.array

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if len(inputs) > 1:
            left = inputs[0]
            right = inputs[1]
            if not isinstance(left, Number) and not isinstance(right, Number):
                if left.shape != right.shape:
                    if left.shape != (0,) and right.shape != (0,):
                        return ufunc_lookup[ufunc.__name__](
                            left.__array__(), right.__array__()
        if ufunc.__name__ == "matmul":
            return self._dot(inputs, kwargs)
        # cls = func_to_numpy_ex(ufunc)
        args = [arg_to_numpy_ex(arg) for arg in inputs]
        return create_ex(ufunc, args)

    def _dot_mv(self, args, kwargs):
        return MVEx(args[0], args[1])

    def _dot_mm(self, args, kwargs):
        return MMEx(args[0], args[1])

    def _dot(self, args, kwargs):
        # scalar result dot
        args = [arg_to_numpy_ex(arg) for arg in args]
        # if is_matrix_matrix(args[0].shape, args[1].shape):
        #     return self._dot_mm(args, kwargs)
        # if is_matrix_vector(args[0].shape, args[1].shape):
        #     return self._dot_mv(args, kwargs)

        left = args[0].__array__()
        right = args[1].__array__()
        return cupy.dot(left, right)

    def __array_function__(self, func, types, args, kwargs):
        if func.__name__ == "dot":
            return self._dot(args, kwargs)
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

    def __gt__(self, other):
        return greater(self, other)

    def __lt__(self, other):
        return less(self, other)

    def dot(self, other, out=None):
        return np.dot(self, other)

    def get(self):
        return self.__array__().get()

    def run(self):

    def reshape(self, *args, **kwargs):
        return NPArray(self.__array__().reshape(*args, **kwargs))

    def __setitem__(self, key, item):
        arr = self.__array__()
        if isinstance(key, DelayArray):
            key = key.__array__()
        if isinstance(item, DelayArray):
            item = item.__array__()

        arr[key] = item

    def __getitem__(self, key):
        if isinstance(key, DelayArray):
            key = key.__array__()
        arr = self.__array__()
        return arr[key]

    def var(self, *args, **kwargs):
        return np.var(self, *args, **kwargs)

    def sum(self, *args, **kwargs):
        return np.sum(self, *args, **kwargs)

    def __len__(self):
        return self.shape[0]

    def T(self):
        if len(self.shape) == 1:
            return self
        return np.transpose(self)

    def repeat(self, *args, **kwargs):
        return repeat(self, *args, **kwargs)

def calc_shape(left, right, op=None):
    if left == (0,):
        return right
    if right is (0,):
        return left
    if op.__name__ in OPS:
        return left
    if op.__name__ == "dot":
        # for now
        if len(left) > 1 and len(right) > 1:
            return (left[0], right[1])
        elif len(left) > 1:
            return (left[0],)
            return (0,)
        return left

class Memoiser(type):
    """Metaclass implementing caching"""

    def __new__(meta, *args, **kwargs):
        cls = super(Memoiser, meta).__new__(meta, *args, **kwargs)
        meta._cache = {}
        return cls

    def __call__(cls, *args):
        if type(args[0]).__name__ == "ndarray":
            key = id(args[0])
            key = hash(args)
        if key not in cls._cache:
            Memoiser._cache[key] = super(Memoiser, cls).__call__(*args)
        return cls._cache[key]

def reset():
    # hacks

class NumpyEx(DelayArray, metaclass=Memoiser):
    children: List["NumpyEx"]
    """Numpy expression"""

    def __init__(self, children: List["NumpyEx"] = []):
        self.dtype = None
        self.children = children

    def __hash__(self):
        Should work because of the Memoizer
        return id(self)

class Funcable:
    def to_op(self):
        return OPS[self.func.__name__]

class ReduceEx(NumpyEx, Funcable):
    def __init__(self, func, arg):
        self.func = func
        self.shape = (0,)

    # func: np.ufunc
    # arg: NumpyEx

class UnaryFuncEx(NumpyEx, Funcable):
    def __init__(self, func, arg):
        self.func = func
        self.shape = arg.shape
        self.dtype = arg.dtype

    def to_op(self):
        return FUNCS[self.func.__name__]

class BinaryFuncEx(NumpyEx):
    def __init__(self, func, left, right):
        super().__init__(children=[left, right])
        self.func = func
        self.shape = calc_shape(left.shape, right.shape, func)
        self.dtype = calc_type(left, right)

    def to_op(self):
        return FUNCS[self.func.__name__]

def pow_ex(func, left, right):
    if not isinstance(right.val, int):
        return BinaryFuncEx(func, left, right)
    ex = left
    for i in range(right.val - 1):
        # will give odd expression tree, but OK
        ex = BinaryNumpyEx(multiply, ex, left)

    return ex

def create_ex(func, args):
    if func.__name__ in OPS:
        return BinaryNumpyEx(func, *args)
    if func.__name__ == "square":
        return BinaryNumpyEx(multiply, args[0], args[0])
    if len(args) == 1:
        return UnaryFuncEx(func, *args)
    if func.__name__ == "power":
        return pow_ex(func, *args)
    return BinaryFuncEx(func, *args)

class BinaryNumpyEx(NumpyEx, Funcable):
    """Binary numpy expression"""

    def __init__(self, func, left, right):
        super().__init__(children=[left, right])
        self.func = func
        self.shape = calc_shape(left.shape, right.shape, func)
        self.dtype = calc_type(left, right)

class MMEx(NumpyEx, Funcable):
    # arg1: NumpyEx
    # arg2: NumpyEx
    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2
        self.shape = calc_shape(arg1.shape, arg2.shape, np.dot)

class MVEx(NumpyEx, Funcable):
    # arg1: NumpyEx
    # arg2: NumpyEx
    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2
        self.shape = calc_shape(arg1.shape, arg2.shape, np.dot)

class DotEx(NumpyEx, Funcable):
    def __init__(self, left, right):
        self.arg1 = left
        self.arg2 = right
        self.shape = calc_shape(left.shape, right.shape, np.dot)
        self._inshape = left.shape

class NPArray(NumpyEx):

    def __init__(self, array):
        self.array = array
        self.shape = array.shape
        self.dtype = array.dtype

    def __hash__(self):
        return id(self.array)

    def __eq__(self, other):
            return self.array is other.array
        except AttributeError:
            return False

    def astype(self, *args, **kwargs):
        old = self.array
        cast_arr = self.array.astype(*args, **kwargs)
        del NPArray._cache[id(old)]
        NPArray._cache[id(cast_arr)] = self
        self.array = cast_arr
        self.dtype = cast_arr.dtype
        return self

class NPRef(NumpyEx):
    """Only for when breaking dependency chains for fusion"""

    def __init__(self, node: NumpyEx, shape: Shape):
        self.ref = node
        self.children = []
        self.shape = shape

    def array(self):
        return self.ref.array

class Scalar(NumpyEx):
    """a scalar"""

    # val: Number
    def __init__(self, val):
        self.val = val
        self.shape = (0,)

    def __hash__(self):
        return hash(self.val)

class Visitor:
    """Visitor ABC"""

@@ 437,220 104,6 @@ def is_matrix_vector(left, right):
    return len(left) > 1 and len(right) == 1

# def calc_type(func, type1, type2):
#     if 'float64' in (type1, type2):
#         return 'float64'
#     elif 'float32' in (type1, type2):
#         return 'float32'
#     elif 'int64' in (type1, type2):
#         return 'int64'
#     else:
#         return type1

def calc_type(node1: NumpyEx, node2: NumpyEx) -> np.dtype:
    if node1.dtype is not None:
        node2.dtype = node1.dtype
        return node1.dtype
    node1.dtype = node2.dtype
    return node2.dtype


def implements(np_function):
    "Register an __array_function__ implementation for DiagonalArray objects."

    def decorator(func):
        HANDLED_FUNCTIONS[np_function] = func
        return func

    return decorator

def arg_to_numpy_ex(arg: Any) -> NumpyEx:
    from numbers import Number

    if isinstance(arg, DelayArray):
        return arg
    elif isinstance(arg, Number):
        return Scalar(arg)
    elif isinstance(arg, cupy.core.core.ndarray) or isinstance(arg, np.ndarray):
        return NPArray(arg)
        raise NotImplementedError

# def func_to_numpy_ex(func):
#     return {
#         'matmul': Matmul,
#         'add': Add,
#         'multiply': Multiply
#         }[func.__name__]

def diag(arr, k=0):
    if isinstance(arr.ex, NPArray):
        arr._ndarray = np.ascontiguousarray(np.diag(arr._ndarray, k))
        assert arr._ndarray.flags["C_CONTIGUOUS"]
        arr.ex = NPArray(arr._ndarray)
        return arr
        return NotImplemented

def diagflat(arr, k=0):
    # keep it simple for now
    return np.diagflat(np.asarray(arr, order="C"))

# @implements(np.sum)
# def sum(arr, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None):
#    print("BLAH")
#    return ReduceEx(np.add, arr)

def var(arr, *args, **kwargs):
    return cupy.var(arr.__array__(), *args, **kwargs)

def sum(arr, *args, **kwargs):
    return cupy.sum(arr.__array__(), *args, **kwargs)

def transpose(arr, *args, **kwargs):
    return cupy.transpose(arr.__array__(), *args, **kwargs)

def roll(arr, *args, **kwargs):
    return cupy.roll(arr.__array__(), *args, **kwargs)

def max(arr, *args, **kwargs):
    return cupy.max(arr.__array__(), *args, **kwargs)

def maximum(arr, *args, **kwargs):
    return cupy.maximum(arr.__array__(), *args, **kwargs)

def average(arr, *args, **kwargs):
    return cupy.average(arr.__array__(), *args, **kwargs)

def repeat(arr, *args, **kwargs):
    return cupy.repeat(arr.__array__(), *args, **kwargs)

def cumsum(arr, *args, **kwargs):
    return cupy.cumsum(arr.__array__(), *args, **kwargs)

def greater(arr1, arr2, *args, **kwargs):
    return cupy.greater(arr1.__array__(), arr2, *args, **kwargs)

def less(arr1, arr2, *args, **kwargs):
    return cupy.less(arr1.__array__(), arr2, *args, **kwargs)

# sum = cast(cupy.sum)
add = np.add
multiply = np.multiply
dot = np.dot
cos = np.cos
sin = np.sin
tan = np.tan
tanh = np.tanh
sinh = np.sinh
cosh = np.cosh
arctan2 = np.arctan2
subtract = np.subtract
exp = np.exp
log = np.log
power = np.power
sqrt = np.sqrt
square = np.square
abs = np.abs
newaxis = cupy.newaxis

# dtypes etc.
double = np.double
float32 = np.float32
uint32 = np.uint32

# Ones and zeros
empty = cast(cupy.empty)
empty_like = cast(cupy.empty_like)
eye = cast(cupy.eye)
identity = cast(cupy.identity)
ones = cast(cupy.ones)
ones_like = cast(cupy.ones_like)
zeros = cast(cupy.zeros)
zeros_like = cast(cupy.zeros_like)
full = cast(cupy.full)
full_like = cast(cupy.full_like)

def tile(arr, *args, **kwargs):

    if isinstance(arr, DelayArray):
        temp = np.array(arr.__array__().get())
    return cupy.tile(temp, *args, **kwargs)

# From existing data

array = cast(cupy.array)
asarray = cast(cupy.asarray)
asanyarray = cast(cupy.asanyarray)
ascontiguousarray = cast(cupy.ascontiguousarray)
asmatrix = cast(np.asmatrix)
copy = cast(cupy.copy)
frombuffer = cast(np.frombuffer)
fromfile = cast(np.fromfile)
fromfunction = cast(np.fromfunction)
fromiter = cast(np.fromiter)
fromstring = cast(np.fromstring)
loadtxt = cast(np.loadtxt)

# Numerical ranges
arange = cast(cupy.arange)
linspace = cast(cupy.linspace)
logspace = cast(cupy.logspace)
geomspace = cast(np.geomspace)

# Building matrices
tri = cast(cupy.tri)
tril = cast(cupy.tril)
triu = cast(cupy.triu)
vander = cast(np.vander)

InputDict = Dict[str, "BaseFragment"]

@@ 697,7 150,7 @@ class Fragment(BaseFragment):
        inargs = [f"T {arg}" for arg in self.kernel_args]
        kern = cupy.ElementwiseKernel(
            f"T out",
            "T out",
            f"{body};\nout = {self.name}",

@@ 705,7 158,7 @@ class Fragment(BaseFragment):

class InputFragment(BaseFragment):
    def __init__(self, name: str, arr: Union[NPArray, NPRef]) -> None:
    def __init__(self, name: str, arr: Union[ir.NPArray, ir.NPRef]) -> None:
        self.name = name
        self._inputs = {self.name: arr}

@@ 729,7 182,7 @@ class InputFragment(BaseFragment):

class ScalarFragment(BaseFragment):
    def __init__(self, val: Scalar) -> None:
    def __init__(self, val: ir.Scalar) -> None:
        self.val = val.val
        self.dtype = val.dtype

@@ 789,7 242,7 @@ class Fuser(Visitor):
        new = []
        for child, shape in zip(node.children, child_shapes):
            if shape != node.shape and shape != (0,):
                new.append(NPRef(child, node.shape))
                new.append(ir.NPRef(child, node.shape))

@@ 817,7 270,7 @@ class CupyEmitter(Visitor):
            self.count += 1
        return visited

    def visit_BinaryNumpyEx(self, node: BinaryNumpyEx) -> BaseFragment:
    def visit_BinaryNumpyEx(self, node: ir.BinaryNumpyEx) -> BaseFragment:
        op = node.to_op()
        left = self.visit(node.children[0])
        right = self.visit(node.children[1])

@@ 826,13 279,13 @@ class CupyEmitter(Visitor):
        stmts = left.stmts + right.stmts + [stmt]
        return Fragment(name, stmts, combine_inputs(left.inputs, right.inputs))

    def visit_UnaryFuncEx(self, node: UnaryFuncEx) -> BaseFragment:
    def visit_UnaryFuncEx(self, node: ir.UnaryFuncEx) -> BaseFragment:
        inner = self.visit(node.children[0])
        name = f"unfunc{self.count}"
        stmts = inner.stmts + [f"T {name} = {node.to_op()}({inner.ref()})"]
        return Fragment(name, stmts, inner.inputs)

    def visit_BinaryFuncEx(self, node: BinaryFuncEx) -> BaseFragment:
    def visit_BinaryFuncEx(self, node: ir.BinaryFuncEx) -> BaseFragment:
        op = node.to_op()
        left = self.visit(node.children[0])
        right = self.visit(node.children[1])

@@ 841,16 294,16 @@ class CupyEmitter(Visitor):
        stmts = left.stmts + right.stmts + [stmt]
        return Fragment(name, stmts, combine_inputs(left.inputs, right.inputs))

    def visit_NPArray(self, node: NPArray) -> BaseFragment:
    def visit_NPArray(self, node: ir.NPArray) -> BaseFragment:
        return InputFragment(f"arr{self.count}", node)

    def visit_NPRef(self, node: NPRef) -> BaseFragment:
    def visit_NPRef(self, node: ir.NPRef) -> BaseFragment:
        return InputFragment(f"ref{self.count}", node)

    def visit_Scalar(self, node: Scalar) -> BaseFragment:
    def visit_Scalar(self, node: ir.Scalar) -> BaseFragment:
        return ScalarFragment(node)

    def visit_ReduceEx(self, node: ReduceEx) -> BaseFragment:
    def visit_ReduceEx(self, node: ir.ReduceEx) -> BaseFragment:
        inner = self.visit(node.children[0])
        name = node.name
        op = node.to_op()

@@ 858,7 311,7 @@ class CupyEmitter(Visitor):
        return NotImplemented

def run_gpu(ex: NumpyEx) -> cupy.array:
def run_gpu(ex: ir.NumpyEx) -> cupy.array:
    visitor = CupyEmitter()
    kerns = [visitor.visit(ex)]
    for kern in kerns: