~magnusmorton/delayrepay

e717043df8564034c25616beb20e4878f4fc0641 — Magnus Morton 6 months ago dab0aa5
renamed run_gpu; added __matmul__
1 files changed, 12 insertions(+), 4 deletions(-)

M delayrepay/delayarray.py
M delayrepay/delayarray.py => delayrepay/delayarray.py +12 -4
@@ 47,7 47,7 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
        try:
            return self.array
        except AttributeError:
            self.array = backend.run_gpu(self)
            self.array = backend.run(self)
            return self.array

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):


@@ 61,6 61,7 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
                            left.__array__(), right.__array__()
                        )
        if ufunc.__name__ == "matmul":
            return None
            return self._dot(inputs, kwargs)
        # cls = func_to_numpy_ex(ufunc)
        args = [arg_to_numpy_ex(arg) for arg in inputs]


@@ 72,6 73,9 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
    def _dot_mm(self, args, kwargs):
        return MMEx(args[0], args[1])

    def __matmul__(self, other):
        return self._dot([self, other], {})

    @cast
    def _dot(self, args, kwargs):
        # scalar result dot


@@ 102,7 106,11 @@ class DelayArray(numpy.lib.mixins.NDArrayOperatorsMixin):
        return self._dot(other, out)

    def get(self):
        return self.__array__().get()
        arr = self.__array__()
        try:
            return arr.get()
        except AttributeError:
            return arr

    def run(self):
        self.__array__()


@@ 177,8 185,8 @@ FUNCS = {
ufunc_lookup = {
    "matmul": backend.np.matmul,
    "add": backend.np.add,
    "multiply": backend.np.cupy.multiply,
    "subtract": backend.np.cupy.subtract,
    "multiply": backend.np.multiply,
    "subtract": backend.np.subtract,
    "true_divide": backend.np.true_divide,
}