(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw)
def __init__(self, name:str, lib:bytes, **kwargs):
  """Build the program by unpickling `lib` and keeping the result on `self.uops`.

  `name` and any extra keyword arguments are accepted but not used in this method.
  """
  # NOTE(review): pickle.loads can execute arbitrary code when fed attacker-controlled bytes;
  # presumably `lib` is only ever produced by this project's own compiler -- confirm.
  decoded_uops = pickle.loads(lib)
  # each entry is annotated as (op, dtype, source indices, arg)
  self.uops: list[tuple[Ops, DType, list[int], Any]] = decoded_uops
| 45 | def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw): |
| 46 | st = time.perf_counter() |
| 47 | warp = list(itertools.product(*[range(x) for x in local_size[::-1]])) |
| 48 | warp_size = len(warp) |
| 49 | void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE} |
| 50 | loop_ends: dict[int, int] = {srcs[1]:i for i, (uop, _, srcs, _) in enumerate(self.uops) if uop == Ops.END} |
| 51 | for idxs in itertools.product(*[range(x) for x in global_size[::-1]]): |
| 52 | values: dict[int, Any] = {} |
| 53 | pbufs: list[memoryview] = list(bufs) |
| 54 | pvals: list[int] = list(vals) |
| 55 | i = 0 |
| 56 | while i < len(self.uops): |
| 57 | uop, dtype, srcs, arg = self.uops[i] |
| 58 | src_values = [values[v] for v in srcs if self.uops[v][0] not in void_ops] |
| 59 | src_dtypes = [self.uops[v][1] for v in srcs if self.uops[v][0] not in void_ops] |
| 60 | if getenv("TRACE"): print(i, uop, dtype, arg, src_values, src_dtypes) |
| 61 | if uop is Ops.END: |
| 62 | i = srcs[1] |
| 63 | continue |
| 64 | if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP): |
| 65 | # in the python emulator, the warp is always in sync |
| 66 | i += 1 |
| 67 | continue |
| 68 | assert dtype is not None, f"{uop} is missing a dtype" |
| 69 | if uop is Ops.STORE: |
| 70 | store_gate = src_values[2] if len(src_values) >= 3 else [True] * warp_size |
| 71 | for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]): |
| 72 | for (m,o),v,g in zip(src_values[0], val, store_gate): |
| 73 | if g: _store(m, o+j, v, src_dtypes[1].scalar()) |
| 74 | i += 1 |
| 75 | continue |
| 76 | if uop is Ops.AFTER: values[i] = src_values[0] |
| 77 | elif uop in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}: |
| 78 | assert isinstance(dtype, PtrDType), dtype |
| 79 | storage_fmt = storage_fmt_for_dtype(dtype.base.scalar()) |
| 80 | if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported") |
| 81 | if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e" |
| 82 | if uop is Ops.DEFINE_REG: |
| 83 | # REGs are per thread |
| 84 | values[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)] |
| 85 | else: |
| 86 | buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.PARAM else pbufs.pop(0) |
| 87 | values[i] = [buf.cast(storage_fmt)] * warp_size |
| 88 | elif uop is Ops.DEFINE_VAR: |
| 89 | values[i] = [pvals.pop(0)] * warp_size |
| 90 | elif uop is Ops.SPECIAL: |
| 91 | if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size |
| 92 | elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp] |
| 93 | elif uop is Ops.CONST: values[i] = [arg] * warp_size |
| 94 | elif uop is Ops.INDEX: |
| 95 | ret:list = [] |
| 96 | if isinstance(src_dtypes[0], ImageDType): |
| 97 | assert len(src_values) == 3, "image index must be 3 srcs" |
| 98 | for m,oy,ox in zip(*src_values): |
| 99 | if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None)) |
| 100 | else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4)) |
| 101 | else: |
| 102 | assert len(src_values) == 2, "non-image index must be 2 srcs" |
nothing calls this directly
no test coverage detected