~magnusmorton/delayrepay

75f9343cbef3fc7ce6775b2c1901ad27c207e37b — Magnus Morton 6 months ago dec613b
cleaned up. imported ir
1 files changed, 12 insertions(+), 14 deletions(-)

M delayrepay/delayarray.py
M delayrepay/delayarray.py => delayrepay/delayarray.py +12 -14
@@ 17,12 17,10 @@ this program. If not, see <http://www.gnu.org/licenses/>.
# Copyright (C) 2020 by Univeristy of Edinburgh

from numbers import Number
from typing import Any, List, Dict, Tuple, Optional, Union, Set
import numpy as np  # type: ignore
import numpy.lib.mixins  # type: ignore
import delayrepay.cuda as backend

Shape = Tuple[int, int]
import delayrepay.ir as ir


def cast(func):


@@ 31,7 29,7 @@ def cast(func):
    def wrapper(*args, **kwargs):
        arr = func(*args, **kwargs)
        if not isinstance(arr, DelayArray):
            arr = NPArray(arr)
            arr = ir.NPArray(arr)
        return arr

    return wrapper


@@ 65,19 63,19 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
        if ufunc.__name__ == "matmul":
            return self._dot(inputs, kwargs)
        # cls = func_to_numpy_ex(ufunc)
        args = [arg_to_numpy_ex(arg) for arg in inputs]
        return create_ex(ufunc, args)
        args = [ir.arg_to_numpy_ex(arg) for arg in inputs]
        return ir.create_ex(ufunc, args)

    def _dot_mv(self, args, kwargs):
        return MVEx(args[0], args[1])
        return ir.MVEx(args[0], args[1])

    def _dot_mm(self, args, kwargs):
        return MMEx(args[0], args[1])
        return ir.MMEx(args[0], args[1])

    @cast
    def _dot(self, args, kwargs):
        # scalar result dot
        args = [arg_to_numpy_ex(arg) for arg in args]
        args = [ir.arg_to_numpy_ex(arg) for arg in args]
        # if is_matrix_matrix(args[0].shape, args[1].shape):
        #     return self._dot_mm(args, kwargs)
        # if is_matrix_vector(args[0].shape, args[1].shape):


@@ 87,7 85,7 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
        right = args[1].__array__()

        # TODO: independent fallback mechanism
        return np.dot(left, right)
        return backend.fallback.dot(left, right)

    def __array_function__(self, func, types, args, kwargs):
        if func.__name__ == "dot":


@@ 101,7 99,7 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
        return less(self, other)

    def dot(self, other, out=None):
        return np.dot(self, other)
        return backend.fallback.dot(self, other)

    def get(self):
        return self.__array__().get()


@@ 110,7 108,7 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
        self.__array__()

    def reshape(self, *args, **kwargs):
        return NPArray(self.__array__().reshape(*args, **kwargs))
        return ir.NPArray(self.__array__().reshape(*args, **kwargs))

    def __setitem__(self, key, item):
        arr = self.__array__()


@@ 160,10 158,10 @@ def implements(np_function):

@implements(np.diag)
def diag(arr, k=0):
    if isinstance(arr.ex, NPArray):
    if isinstance(arr.ex, ir.NPArray):
        arr._ndarray = np.ascontiguousarray(np.diag(arr._ndarray, k))
        assert arr._ndarray.flags["C_CONTIGUOUS"]
        arr.ex = NPArray(arr._ndarray)
        arr.ex = ir.NPArray(arr._ndarray)
        return arr
    else:
        return NotImplemented