~magnusmorton/delayrepay

f28d5046264da9c1367bfa421a099ad1df9715e9 — Magnus Morton 6 months ago 9f32a80
black applied
1 files changed, 133 insertions(+), 121 deletions(-)

M delayrepay/cuda.py
M delayrepay/cuda.py => delayrepay/cuda.py +133 -121
@@ 22,53 22,55 @@ import cupy  # type: ignore
import numpy as np  # type: ignore
import numpy.lib.mixins  # type: ignore

# A shape expressed as a pair of ints.
Shape = Tuple[int, int]

# ufunc ``__name__`` -> infix operator token for that ufunc.
OPS = dict(
    matmul="@",
    add="+",
    multiply="*",
    subtract="-",
    true_divide="/",
)

# NumPy function ``__name__`` -> name emitted for that function in generated code.
FUNCS = {
    "power": "pow",
    "arctan2": "atan2",
    "absolute": "abs",
    "sin": "sin",
    "cos": "cos",
    "tan": "tan",
    "sqrt": "sqrt",
    "log": "log",
    # HACK: "negative" maps to the prefix operator "-" rather than a function name.
    "negative": "-",
    "exp": "exp",
    "tanh": "tanh",
    "sinh": "sinh",
    "cosh": "cosh",
}

# ufunc ``__name__`` -> cupy function that evaluates that ufunc directly on arrays.
ufunc_lookup = {
    "matmul": cupy.matmul,
    "add": cupy.add,
    "multiply": cupy.multiply,
    "subtract": cupy.subtract,
    "true_divide": cupy.true_divide,
}


def cast(func):
    """cast to Delay array decorator

    Wraps *func* so that a result which is not already a ``DelayArray`` is
    wrapped in ``NPArray``; ``DelayArray`` results are returned unchanged.
    The wrapper keeps the name and docstring of *func*.
    """
    # Local import: keeps this function self-contained.
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        arr = func(*args, **kwargs)
        if not isinstance(arr, DelayArray):
            arr = NPArray(arr)
        return arr

    return wrapper


class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
    
    def __init__(self, *args, **kwargs):
        self._memo = None



@@ 90,8 92,10 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
            if not isinstance(left, Number) and not isinstance(right, Number):
                if left.shape != right.shape:
                    if left.shape != (0,) and right.shape != (0,):
                        return ufunc_lookup[ufunc.__name__](left.__array__(), right.__array__())
        if ufunc.__name__ == 'matmul':
                        return ufunc_lookup[ufunc.__name__](
                            left.__array__(), right.__array__()
                        )
        if ufunc.__name__ == "matmul":
            return self._dot(inputs, kwargs)
        # cls = func_to_numpy_ex(ufunc)
        args = [arg_to_numpy_ex(arg) for arg in inputs]


@@ 138,7 142,7 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):

    def reshape(self, *args, **kwargs):
        """Reshape the result of ``self.__array__()`` and wrap it in an ``NPArray``."""
        evaluated = self.__array__()
        reshaped = evaluated.reshape(*args, **kwargs)
        return NPArray(reshaped)
            

    def __setitem__(self, key, item):
        arr = self.__array__()
        if isinstance(key, DelayArray):


@@ 169,10 173,11 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
        if len(self.shape) == 1:
            return self
        return np.transpose(self)
    

    def repeat(self, *args, **kwargs):
        """Method form of the module-level ``repeat``; all arguments are forwarded."""
        repeated = repeat(self, *args, **kwargs)
        return repeated


def calc_shape(left, right, op=None):
    if left == (0,):
        return right


@@ 180,7 185,7 @@ def calc_shape(left, right, op=None):
        return left
    if op.__name__ in OPS:
        return left
    if op.__name__ == 'dot':
    if op.__name__ == "dot":
        # for now
        if len(left) > 1 and len(right) > 1:
            return (left[0], right[1])


@@ 193,8 198,8 @@ def calc_shape(left, right, op=None):


class Memoiser(type):
    '''Metaclass implementing caching'''
    
    """Metaclass implementing caching"""

    def __new__(meta, *args, **kwargs):
        cls = super(Memoiser, meta).__new__(meta, *args, **kwargs)
        meta._cache = {}


@@ 209,25 214,28 @@ class Memoiser(type):
            Memoiser._cache[key] = super(Memoiser, cls).__call__(*args)
        return cls._cache[key]


def reset():
    """Empty the cache dictionary stored on ``Memoiser``."""
    # hacks
    cache = Memoiser._cache
    cache.clear()


class NumpyEx(DelayArray, metaclass=Memoiser):
    """Numpy expression"""

    children: List["NumpyEx"]

    def __init__(self, children: Union[List["NumpyEx"], None] = None):
        """Initialise the node.

        Args:
            children: sub-expressions of this node. When omitted, each
                instance gets its own new empty list, so nodes never share
                one list object.
        """
        super().__init__()
        self.dtype = None
        self.children = [] if children is None else children

    def __hash__(self):
        """
        Should work because of the Memoizer
        """
        return id(self)


class Funcable:
    def to_op(self):
        """Return the ``OPS`` entry keyed by the ``__name__`` of ``self.func``."""
        func_name = self.func.__name__
        return OPS[func_name]


@@ 244,7 252,6 @@ class ReduceEx(NumpyEx, Funcable):


class UnaryFuncEx(NumpyEx, Funcable):

    def __init__(self, func, arg):
        super().__init__(children=[arg])
        self.func = func


@@ 256,16 263,16 @@ class UnaryFuncEx(NumpyEx, Funcable):


class BinaryFuncEx(NumpyEx):
    """Expression node for a function applied to two operands."""

    def __init__(self, func, left, right):
        operands = [left, right]
        super().__init__(children=operands)
        self.func = func
        self.shape = calc_shape(left.shape, right.shape, func)
        self.dtype = calc_type(left, right)

    def to_op(self):
        """Return the ``FUNCS`` entry keyed by the ``__name__`` of ``self.func``."""
        func_name = self.func.__name__
        return FUNCS[func_name]


def pow_ex(func, left, right):
    if not isinstance(right.val, int):
        return BinaryFuncEx(func, left, right)


@@ 277,29 284,27 @@ def pow_ex(func, left, right):
    return ex



def create_ex(func, args):
    """Build the expression node for applying *func* to *args*.

    Checks run in this order: names listed in ``OPS`` become a
    ``BinaryNumpyEx``; ``square`` becomes a multiplication of the argument
    with itself; any other single-argument call becomes a ``UnaryFuncEx``;
    ``power`` is delegated to ``pow_ex``; everything else becomes a
    ``BinaryFuncEx``.
    """
    name = func.__name__
    if name in OPS:
        return BinaryNumpyEx(func, *args)
    if name == "square":
        return BinaryNumpyEx(multiply, args[0], args[0])
    if len(args) == 1:
        return UnaryFuncEx(func, *args)
    if name == "power":
        return pow_ex(func, *args)
    return BinaryFuncEx(func, *args)


class BinaryNumpyEx(NumpyEx, Funcable):
    """Binary numpy expression"""

    def __init__(self, func, left, right):
        operands = [left, right]
        super().__init__(children=operands)
        self.func = func
        self.shape = calc_shape(left.shape, right.shape, func)
        self.dtype = calc_type(left, right)

  

class MMEx(NumpyEx, Funcable):
    # arg1: NumpyEx


@@ 322,7 327,6 @@ class MVEx(NumpyEx, Funcable):


class DotEx(NumpyEx, Funcable):

    def __init__(self, left, right):
        super().__init__()
        self.arg1 = left


@@ 331,10 335,8 @@ class DotEx(NumpyEx, Funcable):
        self._inshape = left.shape




class NPArray(NumpyEx):
    '''ndarray'''
    """ndarray"""

    def __init__(self, array):
        super().__init__()


@@ 354,16 356,17 @@ class NPArray(NumpyEx):
    def astype(self, *args, **kwargs):
        """Cast the wrapped array, update this node in place and return ``self``.

        The cache entry keyed by the old array's ``id`` is removed and this
        node is stored again under the ``id`` of the cast array.
        """
        old = self.array
        cast_arr = old.astype(*args, **kwargs)
        # Delete exactly once: deleting the same key twice raises KeyError.
        del NPArray._cache[id(old)]
        NPArray._cache[id(cast_arr)] = self
        self.array = cast_arr
        self.dtype = cast_arr.dtype
        return self


class NPRef(NumpyEx):
    '''Only for when breaking dependency chains for fusion'''
    
    def __init__(self, node:NumpyEx, shape:Shape):
    """Only for when breaking dependency chains for fusion"""

    def __init__(self, node: NumpyEx, shape: Shape):
        super().__init__()
        self.ref = node
        self.children = []


@@ 373,8 376,10 @@ class NPRef(NumpyEx):
    def array(self):
        return self.ref.array


class Scalar(NumpyEx):
    '''a scalar'''
    """a scalar"""

    # val: Number
    def __init__(self, val):
        super().__init__()


@@ 386,13 391,14 @@ class Scalar(NumpyEx):


class Visitor:
    '''Visitor ABC'''
    """Visitor ABC"""

    def visit(self, node) -> Any:
        """Dispatch *node* to its handler and return the handler's result.

        Lists go to ``list_visit``. Any other node goes to the method named
        ``visit_<ClassName>``, or to ``default_visit`` when no such method
        exists.
        """
        if isinstance(node, list):
            return self.list_visit(node)
        handler_name = "visit_" + node.__class__.__name__
        handler = getattr(self, handler_name, self.default_visit)
        return handler(node)



@@ 404,7 410,8 @@ class Visitor:


class NumpyVisitor(Visitor):
    '''Visits Numpy Expression'''
    """Visits Numpy Expression"""

    def __init__(self):
        self.visits = 0



@@ 417,7 424,7 @@ class NumpyVisitor(Visitor):
        return node

    def walk(self, tree):
        """Top-level entry: zero ``self.visits``, then visit *tree* and return the result."""
        self.visits = 0
        outcome = self.visit(tree)
        return outcome



@@ 430,8 437,6 @@ def is_matrix_vector(left, right):
    return len(left) > 1 and len(right) == 1




# def calc_type(func, type1, type2):
#     if 'float64' in (type1, type2):
#         return 'float64'


@@ 442,6 447,7 @@ def is_matrix_vector(left, right):
#     else:
#         return type1


def calc_type(node1: NumpyEx, node2: NumpyEx) -> np.dtype:
    if node1.dtype is not None:
        node2.dtype = node1.dtype


@@ 455,14 461,17 @@ HANDLED_FUNCTIONS = {}

def implements(np_function):
    """Return a decorator that registers a handler for *np_function*.

    The decorated callable is stored in ``HANDLED_FUNCTIONS`` under
    *np_function* and returned unchanged.
    """

    def register(handler):
        HANDLED_FUNCTIONS[np_function] = handler
        return handler

    return register


def arg_to_numpy_ex(arg: Any) -> NumpyEx:
    from numbers import Number

    if isinstance(arg, DelayArray):
        return arg
    elif isinstance(arg, Number):


@@ 486,7 495,7 @@ def arg_to_numpy_ex(arg: Any) -> NumpyEx:
def diag(arr, k=0):
    if isinstance(arr.ex, NPArray):
        arr._ndarray = np.ascontiguousarray(np.diag(arr._ndarray, k))
        assert(arr._ndarray.flags['C_CONTIGUOUS'])
        assert arr._ndarray.flags["C_CONTIGUOUS"]
        arr.ex = NPArray(arr._ndarray)
        return arr
    else:


@@ 497,21 506,25 @@ def diag(arr, k=0):
@cast
def diagflat(arr, k=0):
    """Return a 2-D array with the flattened *arr* on the *k*-th diagonal.

    Args:
        arr: input converted with ``np.asarray(..., order="C")``.
        k: diagonal offset passed on to ``np.diagflat`` (0 is the main diagonal).
    """
    # keep it simple for now
    flat_source = np.asarray(arr, order="C")
    return np.diagflat(flat_source, k)


#@implements(np.sum)
#def sum(arr, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None):
# @implements(np.sum)
# def sum(arr, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None):
#    print("BLAH")
#    return ReduceEx(np.add, arr)


@implements(np.var)
def var(arr, *args, **kwargs):
    """Handler for ``np.var``: calls ``cupy.var`` on ``arr.__array__()``."""
    data = arr.__array__()
    return cupy.var(data, *args, **kwargs)


@implements(np.sum)
def sum(arr, *args, **kwargs):
    """Handler for ``np.sum``: calls ``cupy.sum`` on ``arr.__array__()``."""
    data = arr.__array__()
    return cupy.sum(data, *args, **kwargs)


@implements(np.transpose)
@cast
def transpose(arr, *args, **kwargs):


@@ 523,36 536,46 @@ def transpose(arr, *args, **kwargs):
def roll(arr, *args, **kwargs):
    return cupy.roll(arr.__array__(), *args, **kwargs)


@implements(np.max)
def max(arr, *args, **kwargs):
    """Handler for ``np.max``: calls ``cupy.max`` on ``arr.__array__()``."""
    data = arr.__array__()
    return cupy.max(data, *args, **kwargs)


@implements(np.maximum)
@cast
def maximum(arr, *args, **kwargs):
    """Handler for ``np.maximum``: calls ``cupy.maximum`` on ``arr.__array__()``.

    ``cast`` is applied first so that the callable stored by ``implements``
    is the casting wrapper, not the bare function.
    """
    return cupy.maximum(arr.__array__(), *args, **kwargs)


@implements(np.average)
def average(arr, *args, **kwargs):
    """Handler for ``np.average``: calls ``cupy.average`` on ``arr.__array__()``."""
    data = arr.__array__()
    return cupy.average(data, *args, **kwargs)


@implements(np.repeat)
@cast
def repeat(arr, *args, **kwargs):
    """Handler for ``np.repeat``: calls ``cupy.repeat`` on ``arr.__array__()``."""
    data = arr.__array__()
    return cupy.repeat(data, *args, **kwargs)


@implements(np.cumsum)
@cast
def cumsum(arr, *args, **kwargs):
    """Handler for ``np.cumsum``: calls ``cupy.cumsum`` on ``arr.__array__()``.

    ``cast`` is applied first so that the callable stored by ``implements``
    is the casting wrapper, not the bare function.
    """
    return cupy.cumsum(arr.__array__(), *args, **kwargs)


@implements(np.greater)
def greater(arr1, arr2, *args, **kwargs):
    """Handler for ``np.greater``; only *arr1* is converted with ``__array__()``."""
    lhs = arr1.__array__()
    return cupy.greater(lhs, arr2, *args, **kwargs)


@implements(np.less)
def less(arr1, arr2, *args, **kwargs):
    """Handler for ``np.less``; only *arr1* is converted with ``__array__()``."""
    lhs = arr1.__array__()
    return cupy.less(lhs, arr2, *args, **kwargs)
#sum = cast(cupy.sum)


# sum = cast(cupy.sum)
add = np.add
multiply = np.multiply
dot = np.dot


@@ 572,7 595,7 @@ square = np.square
abs = np.abs
newaxis = cupy.newaxis

#dtypes etc.
# dtypes etc.
double = np.double
float32 = np.float32
uint32 = np.uint32


@@ 589,6 612,7 @@ zeros_like = cast(cupy.zeros_like)
full = cast(cupy.full)
full_like = cast(cupy.full_like)


@implements(np.tile)
@cast
def tile(arr, *args, **kwargs):


@@ 598,6 622,7 @@ def tile(arr, *args, **kwargs):
        print(type(temp))
    return cupy.tile(temp, *args, **kwargs)


# From existing data

array = cast(cupy.array)


@@ 626,7 651,8 @@ tril = cast(cupy.tril)
triu = cast(cupy.triu)
vander = cast(np.vander)

InputDict = Dict[str, 'BaseFragment']
InputDict = Dict[str, "BaseFragment"]


class BaseFragment:
    def __init__(self):


@@ 634,7 660,7 @@ class BaseFragment:
        self.stmts = []
        self._expr = None
        self._inputs = {}
        

    @property
    def inputs(self) -> InputDict:
        return self._inputs


@@ 651,16 677,11 @@ def dedup(seq):


class Fragment(BaseFragment):

    def __init__(self, name: str, stmts: List[str], inputs: InputDict) -> None:
        """Store the fragment's name, its statements and its input mapping."""
        self.name = name
        self.stmts = stmts
        self._inputs = inputs
        # self.dtype = np.float32

    def ref(self) -> str:
        """Return the name under which this fragment is referenced."""
        fragment_name = self.name
        return fragment_name


@@ 678,14 699,12 @@ class Fragment(BaseFragment):
            ",".join(inargs),
            f"T out",
            f"{body};\nout = {self.name}",
            f"delay_repay_{self.name}"
            f"delay_repay_{self.name}",
        )
        return kern



class InputFragment(BaseFragment):

    def __init__(self, name: str, arr: Union[NPArray, NPRef]) -> None:
        super().__init__()
        self.name = name


@@ 708,6 727,7 @@ class InputFragment(BaseFragment):
#              np.dtype("int32"): "",
#              np.dtype("int64"): ""}


class ScalarFragment(BaseFragment):
    def __init__(self, val: Scalar) -> None:
        super().__init__()


@@ 722,43 742,42 @@ class ScalarFragment(BaseFragment):


class ReductionKernel(Fragment):
    """Fragment that is turned into a ``cupy.ReductionKernel``."""

    def to_kern(self):
        """Construct the kernel once and return it.

        The arguments are, in order: the comma-joined ``self.inargs``, the
        output parameter ``"T out"``, ``self.expr``, ``self.redex``, the
        post-map expression ``"out = a"``, the identity ``0`` and
        ``self.name``.
        """
        return cupy.ReductionKernel(
            ",".join(self.inargs),
            "T out",
            self.expr,
            self.redex,
            "out = a",
            0,
            self.name,
        )





def combine_inputs(*args: InputDict) -> InputDict:
    """Merge all given input dicts into one new dict.

    When a key occurs in more than one dict, the value from the later
    argument is kept.
    """
    merged = {}
    for inputs in args:
        merged.update(inputs)
    return merged


class PrettyPrinter(Visitor):
    """Prints the class name of each node before visiting its children."""

    def visit(self, node):
        """Print a node's class name and recurse; lists go to ``list_visit``."""
        if not isinstance(node, list):
            print(type(node).__name__)
            self.visit(node.children)
            return None
        return self.list_visit(node)


class Fuser(Visitor):
    def __init__(self):
        """Start with an empty ``seen`` mapping and an empty ``splits`` list."""
        super().__init__()
        self.splits = []
        self.seen = {}
    

    def fuse(self, node):
        """Visit *node*, append it to ``self.splits`` and return that list."""
        self.visit(node)
        splits = self.splits
        splits.append(node)
        return splits


@@ 779,8 798,8 @@ class Fuser(Visitor):

        return node.shape

class CupyEmitter(Visitor):

class CupyEmitter(Visitor):
    def __init__(self):
        super().__init__()
        self.ins = {}


@@ 798,54 817,47 @@ class CupyEmitter(Visitor):
            self.count += 1
        return visited

    def visit_BinaryNumpyEx(self, node: BinaryNumpyEx) -> BaseFragment:
        """Emit ``T <name> = <left> <op> <right>`` after both operands' statements.

        The result name uses ``self.count`` as read after both children were
        visited; the inputs of both operands are combined.
        """
        op = node.to_op()
        left = self.visit(node.children[0])
        right = self.visit(node.children[1])
        name = f"binex{self.count}"
        stmt = f"T {name} = {left.ref()} {op} {right.ref()}"
        stmts = left.stmts + right.stmts + [stmt]
        return Fragment(name, stmts, combine_inputs(left.inputs, right.inputs))

    def visit_UnaryFuncEx(self, node: UnaryFuncEx) -> BaseFragment:
        """Emit ``T <name> = <op>(<inner>)`` after the operand's statements."""
        inner = self.visit(node.children[0])
        name = f"unfunc{self.count}"
        stmts = inner.stmts + [f"T {name} = {node.to_op()}({inner.ref()})"]
        return Fragment(name, stmts, inner.inputs)

    def visit_BinaryFuncEx(self, node: BinaryFuncEx) -> BaseFragment:
        """Emit ``T <name> = <op>(<left>, <right>)`` after both operands' statements."""
        op = node.to_op()
        left = self.visit(node.children[0])
        right = self.visit(node.children[1])
        name = f"binfunc{self.count}"
        stmt = f"T {name} = {op}({left.ref()}, {right.ref()})"
        stmts = left.stmts + right.stmts + [stmt]
        return Fragment(name, stmts, combine_inputs(left.inputs, right.inputs))

    def visit_NPArray(self, node: NPArray) -> BaseFragment:
        """Wrap *node* in an ``InputFragment`` named ``arr<count>``."""
        return InputFragment(f"arr{self.count}", node)

    def visit_NPRef(self, node: NPRef) -> BaseFragment:
        """Wrap *node* in an ``InputFragment`` named ``ref<count>``."""
        return InputFragment(f"ref{self.count}", node)

    def visit_Scalar(self, node: Scalar) -> BaseFragment:
        """Wrap *node* in a ``ScalarFragment``."""
        return ScalarFragment(node)

    def visit_ReduceEx(self, node: ReduceEx) -> BaseFragment:
        """Placeholder: visits the operand, then returns ``NotImplemented``."""
        # The three values below are computed but not used yet.
        inner = self.visit(node.children[0])
        name = node.name
        op = node.to_op()

        return NotImplemented


def run_gpu(ex: NumpyEx) -> cupy.array:
    visitor = CupyEmitter()
    kerns = [visitor.visit(ex)]