MCPcopy
hub / github.com/tinygrad/tinygrad / __call__

Method __call__

tinygrad/runtime/ops_python.py:45–202  ·  view source on GitHub ↗
(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw)

Source from the content-addressed store, hash-verified

43 def __init__(self, name:str, lib:bytes, **kwargs):  # name and **kwargs are accepted but not used by this constructor
44 self.uops: list[tuple[Ops, DType, list[int], Any]] = pickle.loads(lib)  # lib unpickles to the uop list, each entry (op, dtype, source uop indices, arg). NOTE(review): pickle.loads is only safe on trusted bytes - confirm lib is always produced in-process
45 def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
46 st = time.perf_counter()
47 warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
48 warp_size = len(warp)
49 void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE}
50 loop_ends: dict[int, int] = {srcs[1]:i for i, (uop, _, srcs, _) in enumerate(self.uops) if uop == Ops.END}
51 for idxs in itertools.product(*[range(x) for x in global_size[::-1]]):
52 values: dict[int, Any] = {}
53 pbufs: list[memoryview] = list(bufs)
54 pvals: list[int] = list(vals)
55 i = 0
56 while i < len(self.uops):
57 uop, dtype, srcs, arg = self.uops[i]
58 src_values = [values[v] for v in srcs if self.uops[v][0] not in void_ops]
59 src_dtypes = [self.uops[v][1] for v in srcs if self.uops[v][0] not in void_ops]
60 if getenv("TRACE"): print(i, uop, dtype, arg, src_values, src_dtypes)
61 if uop is Ops.END:
62 i = srcs[1]
63 continue
64 if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP):
65 # in the python emulator, the warp is always in sync
66 i += 1
67 continue
68 assert dtype is not None, f"{uop} is missing a dtype"
69 if uop is Ops.STORE:
70 store_gate = src_values[2] if len(src_values) >= 3 else [True] * warp_size
71 for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
72 for (m,o),v,g in zip(src_values[0], val, store_gate):
73 if g: _store(m, o+j, v, src_dtypes[1].scalar())
74 i += 1
75 continue
76 if uop is Ops.AFTER: values[i] = src_values[0]
77 elif uop in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
78 assert isinstance(dtype, PtrDType), dtype
79 storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
80 if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported")
81 if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
82 if uop is Ops.DEFINE_REG:
83 # REGs are per thread
84 values[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
85 else:
86 buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.PARAM else pbufs.pop(0)
87 values[i] = [buf.cast(storage_fmt)] * warp_size
88 elif uop is Ops.DEFINE_VAR:
89 values[i] = [pvals.pop(0)] * warp_size
90 elif uop is Ops.SPECIAL:
91 if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
92 elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
93 elif uop is Ops.CONST: values[i] = [arg] * warp_size
94 elif uop is Ops.INDEX:
95 ret:list = []
96 if isinstance(src_dtypes[0], ImageDType):
97 assert len(src_values) == 3, "image index must be 3 srcs"
98 for m,oy,ox in zip(*src_values):
99 if ox < 0 or ox >= src_dtypes[0].shape[1] or oy < 0 or oy >= src_dtypes[0].shape[0]: ret.append((m, None))
100 else: ret.append((m, ox*4 + oy*src_dtypes[0].shape[1]*4))
101 else:
102 assert len(src_values) == 2, "non-image index must be 2 srcs"

Callers

nothing calls this directly

Calls 13

getenv · Function · 0.90
storage_fmt_for_dtype · Function · 0.90
bitcast · Function · 0.90
get_single_element · Function · 0.90
all_same · Function · 0.90
exec_alu · Function · 0.90
_store · Function · 0.85
scalar · Method · 0.80
append · Method · 0.80
load · Function · 0.70
cast · Method · 0.45
get · Method · 0.45

Tested by

no test coverage detected