Repository: github.com/tinygrad/tinygrad

Function check_schedule

test/null/test_schedule.py:10–29
(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True)


class KernelCountException(Exception): pass
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
  # realize any prerequisite tensors first, with debug output and match-stat tracking silenced
  if to_prerealize:
    with Context(DEBUG=0, TRACK_MATCH_STATS=0): Tensor.realize(*to_prerealize)
  # build the linear schedule from a Tensor, a list of Tensors, or a bare UOp
  if isinstance(t, Tensor): linear, var_vals = t.linear_with_vars()
  elif isinstance(t, list) and isinstance(t[0], Tensor): linear, var_vals = Tensor.linear_with_vars(*t)
  else:
    assert isinstance(t, UOp), f"can't schedule {t}"
    linear, var_vals = Tensor(t).linear_with_vars()
  # a call on a tuple of devices counts once per device; with filter_sink, only SINK-rooted calls are counted
  kernel_cnt = sum((len(call.device) if isinstance(call.device, tuple) else 1)
                   for call in linear.src if call.src[0].op is Ops.SINK or not filter_sink)
  if kernel_cnt != allowed:
    print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}")
    if DEBUG >= 3:
      for i,call in enumerate(linear.src):
        print("kernel", i+1)
        print(call.src[0])
    raise KernelCountException(f"{kernel_cnt} != {allowed}")
  # test compiling the linear
  compile_linear(linear)
  return linear, var_vals

def _realize_weights(m):
  for p in nn.state.get_parameters(m): p.realize()
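The kernel-count rule in check_schedule can be isolated from tinygrad for illustration. The sketch below is a minimal, self-contained approximation, assuming each entry of linear.src exposes a src tuple whose first element has an op, plus a device that is either a single name or a tuple of names. Op, Node, Call, count_kernels and check_count are hypothetical stand-ins, not tinygrad API; only the sum expression and the exception message mirror the listing above.

```python
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum, auto

class Op(Enum):
    # hypothetical stand-in for tinygrad's Ops; COPY is just a placeholder non-SINK op
    SINK = auto()
    COPY = auto()

@dataclass
class Node:
    op: Op

@dataclass
class Call:
    # hypothetical stand-in for one entry of linear.src
    src: tuple
    device: str | tuple[str, ...]

class KernelCountException(Exception): pass

def count_kernels(calls, filter_sink=True):
    # multi-device calls count once per device; with filter_sink, non-SINK calls are skipped
    return sum((len(c.device) if isinstance(c.device, tuple) else 1)
               for c in calls if c.src[0].op is Op.SINK or not filter_sink)

def check_count(calls, allowed, filter_sink=True):
    # raise if the count differs from the expected number, as check_schedule does
    cnt = count_kernels(calls, filter_sink)
    if cnt != allowed:
        raise KernelCountException(f"{cnt} != {allowed}")
    return cnt

calls = [
    Call((Node(Op.SINK),), "CPU"),              # one kernel on one device
    Call((Node(Op.SINK),), ("CPU:0", "CPU:1")), # one kernel sharded over two devices
    Call((Node(Op.COPY),), "CPU"),              # not SINK-rooted, filtered by default
]
print(count_kernels(calls))                    # 3
print(count_kernels(calls, filter_sink=False)) # 4
```

With the default filter_sink=True, the sharded call contributes two to the total and the placeholder COPY call contributes nothing; turning the filter off counts every call.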

Calls 6

Context · class · 0.90
Tensor · class · 0.90
compile_linear · function · 0.90
realize · method · 0.80
linear_with_vars · method · 0.80
