~magnusmorton/delayrepay

ref: 5052fe013379135757f6e26effed71e083221f26 delayrepay/delayrepay/cuda.py -rw-r--r-- 5.3 KiB
5052fe01 — Magnus Morton env based backend select 5 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
""" CUDA numpy array with delayed-evaluation semantics

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with
this program. If not, see <http://www.gnu.org/licenses/>.
"""
# Copyright (C) 2020 by University of Edinburgh

from typing import List, Dict, Tuple
import cupy  # type: ignore

from .visitor import Visitor

# Both names are bound to the cupy module itself.
# NOTE(review): presumably these let other modules of the package address the
# array backend under a backend-neutral name — confirm against the importers.
np = cupy
fallback = cupy

# A pair of ints; not referenced by any other definition visible in this file.
Shape = Tuple[int, int]

# Mapping from a generated kernel-argument name to the object supplying it.
# NOTE(review): the value type is annotated as BaseFragment, but InputFragment
# stores the object it was constructed with and run() reads ``.array`` from
# each value — confirm which type is intended.
InputDict = Dict[str, "BaseFragment"]


def is_ndarray(arr):
    """Return True when *arr* is a CuPy array.

    The check uses the public ``cupy.ndarray`` name instead of reaching into
    the private ``cupy.core.core`` module path, which is an implementation
    detail of CuPy and is not available under that name in newer releases.
    """
    return isinstance(arr, cupy.ndarray)


class BaseFragment:
    """State shared by every fragment of generated kernel source."""

    def __init__(self):
        # Every container starts empty; subclasses overwrite what they need.
        self._inputs = {}
        self._expr = None
        self.stmts = []
        self.name = None

    @property
    def inputs(self) -> InputDict:
        """The mapping stored in ``_inputs``."""
        return self._inputs

    @property
    def kernel_args(self) -> InputDict:
        """The same mapping as :attr:`inputs`, under the name kernel code uses."""
        return self._inputs


def dedup(seq):
    """Return the items of *seq* as a list with repeated items dropped.

    The first occurrence of each item is kept and the original order is
    preserved. Items must be hashable because they are tracked in a set.
    """
    unique = []
    already = set()
    for item in seq:
        if item in already:
            continue
        already.add(item)
        unique.append(item)
    return unique


class Fragment(BaseFragment):
    """A named intermediate value together with the statements computing it.

    ``stmts`` is the list of source statements accumulated so far, ``name`` is
    the variable this fragment is referred to by, and ``inputs`` maps kernel
    argument names to the objects that supply them.
    """

    def __init__(self, name: str, stmts: List[str], inputs: InputDict) -> None:
        # Run the base initialiser first so every attribute BaseFragment
        # defines (such as ``_expr``) also exists on this instance, the same
        # way InputFragment and ScalarFragment do.
        super().__init__()
        self.name = name
        self.stmts = stmts
        self._inputs = inputs

    def ref(self) -> str:
        """Return the variable name used to refer to this fragment's value."""
        return self.name

    def to_input(self):
        """Return ``{self.name: self.node.array}``."""
        # NOTE(review): nothing visible in this file assigns ``self.node``, so
        # this raises AttributeError unless a caller sets it first — confirm
        # whether this method is still used.
        return {self.name: self.node.array}

    def to_kern(self) -> cupy.ElementwiseKernel:
        """Build the ``cupy.ElementwiseKernel`` that evaluates this fragment.

        Every kernel argument and the output are declared with the single
        type placeholder ``T``. The kernel body is the fragment's statements
        followed by an assignment of ``self.name`` to ``out``.
        """
        # A statement may be listed more than once in ``stmts``; emit each
        # distinct statement a single time, in first-seen order.
        body = ";\n".join(dedup(self.stmts))
        in_params = ",".join(f"T {arg}" for arg in self.kernel_args)
        return cupy.ElementwiseKernel(
            in_params,
            "T out",
            f"{body};\nout = {self.name}",
            f"delay_repay_{self.name}",
        )


class InputFragment(BaseFragment):
    """Fragment for a kernel input that generated code refers to by name."""

    def __init__(self, name, arr) -> None:
        super().__init__()
        # The fragment's only input is *arr*, registered under its own name.
        self._inputs = {name: arr}
        self.name = name

    def expr(self) -> str:
        """Return the input's name as its expression text."""
        return f"{self.name}"

    def ref(self) -> str:
        """Return the input's name for use inside other statements."""
        return f"{self.name}"


class ScalarFragment(BaseFragment):
    """Fragment for a scalar whose value is written into the source inline."""

    def __init__(self, val) -> None:
        super().__init__()
        # *val* is an object exposing ``val`` and ``dtype``; copy both over.
        scalar = val
        self.val = scalar.val
        self.dtype = scalar.dtype

    def expr(self) -> str:
        """Return the scalar value rendered with ``str``."""
        return str(self.val)

    def ref(self) -> str:
        """Return the scalar value rendered with ``str``."""
        return str(self.val)


class ReductionKernel(Fragment):
    """Fragment variant that is compiled with ``cupy.ReductionKernel``."""

    def to_kern(self):
        """Build the reduction kernel from ``inargs``, ``expr`` and ``redex``.

        NOTE(review): none of these three attributes is assigned by the
        classes visible in this file; they have to be set before this method
        is called — confirm where that happens.
        """
        in_params = ",".join(self.inargs)
        map_expr = self.expr
        reduce_expr = self.redex
        # Positional arguments: input params, output params, map expression,
        # reduce expression, post-map expression, identity, kernel name.
        return cupy.ReductionKernel(
            in_params,
            "T out",
            map_expr,
            reduce_expr,
            "out = a",
            0,
            self.name,
        )


def combine_inputs(*args: InputDict) -> InputDict:
    """Merge any number of input mappings into one new dict.

    The mappings are applied left to right, so when a name occurs in more
    than one of them the value from the last such mapping is kept. None of
    the arguments is modified.
    """
    return {name: value for mapping in args for name, value in mapping.items()}


class CupyEmitter(Visitor):
    """Visitor that turns an expression tree into kernel source fragments.

    Each ``visit_*`` method returns a fragment object for one node. Results
    are cached per node, and a running counter supplies the numeric suffix
    of generated variable names.
    """

    def __init__(self):
        super().__init__()
        self.ins = {}
        self.outs = []
        self.kernels = []
        # node -> fragment already produced for it
        self.seen = {}
        # number of nodes visited (and cached) so far; used in generated names
        self.count = 0

    def visit(self, node):
        """Return the fragment for *node*, reusing the cached one if present.

        Nodes are used as dictionary keys, so they must be hashable.
        """
        if node in self.seen:
            return self.seen[node]
        visited = super().visit(node)
        self.seen[node] = visited
        self.count += 1
        return visited

    def visit_BinaryNumpyEx(self, node) -> BaseFragment:
        """Emit ``T <name> = <left> <op> <right>`` for an infix binary node."""
        op = node.to_op()
        left = self.visit(node.children[0])
        right = self.visit(node.children[1])
        # The name is taken after the children so it reflects the updated count.
        name = f"binex{self.count}"
        stmt = f"T {name} = {left.ref()} {op} {right.ref()}"
        stmts = left.stmts + right.stmts + [stmt]
        return Fragment(name, stmts, combine_inputs(left.inputs, right.inputs))

    def visit_UnaryFuncEx(self, node) -> BaseFragment:
        """Emit ``T <name> = <func>(<inner>)`` for a one-argument function node."""
        inner = self.visit(node.children[0])
        name = f"unfunc{self.count}"
        stmts = inner.stmts + [f"T {name} = {node.to_op()}({inner.ref()})"]
        return Fragment(name, stmts, inner.inputs)

    def visit_BinaryFuncEx(self, node) -> BaseFragment:
        """Emit ``T <name> = <func>(<left>, <right>)`` for a two-argument function node."""
        op = node.to_op()
        left = self.visit(node.children[0])
        right = self.visit(node.children[1])
        name = f"binfunc{self.count}"
        stmt = f"T {name} = {op}({left.ref()}, {right.ref()})"
        stmts = left.stmts + right.stmts + [stmt]
        return Fragment(name, stmts, combine_inputs(left.inputs, right.inputs))

    def visit_NPArray(self, node) -> BaseFragment:
        """Wrap *node* as a kernel input named ``arr<count>``."""
        return InputFragment(f"arr{self.count}", node)

    def visit_NPRef(self, node) -> BaseFragment:
        """Wrap *node* as a kernel input named ``ref<count>``."""
        return InputFragment(f"ref{self.count}", node)

    def visit_Scalar(self, node) -> BaseFragment:
        """Wrap *node* as an inline scalar fragment."""
        return ScalarFragment(node)

    def visit_ReduceEx(self, node) -> BaseFragment:
        """Reject reduction nodes, which this emitter cannot generate yet.

        Raises:
            NotImplementedError: always. Returning the ``NotImplemented``
                constant instead would let it be cached and used as if it
                were a fragment, failing later with an unrelated
                AttributeError.
        """
        raise NotImplementedError(
            "reduction expressions are not supported by CupyEmitter"
        )


def run(ex) -> cupy.ndarray:
    """Generate, compile and execute the kernel for expression *ex*.

    The expression is converted to a fragment by ``CupyEmitter``, the
    fragment is compiled via its ``to_kern`` method, and the compiled kernel
    is called with the ``array`` attribute of every kernel argument.

    Returns:
        The value returned by the compiled kernel. It is also stored on the
        fragment as ``array``.
    """
    kern = CupyEmitter().visit(ex)
    compiled = kern.to_kern()
    # Iterating the same mapping that to_kern iterated keeps the argument
    # order aligned with the kernel's parameter list.
    inputs = [arg.array for arg in kern.kernel_args.values()]
    ret = compiled(*inputs)
    kern.array = ret
    return ret