(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True)
| 8 | |
class KernelCountException(Exception):
  """Raised by check_schedule when the counted kernels differ from the allowed number."""
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
  """Linearize `t`, verify its kernel count equals `allowed`, then compile the result.

  Args:
    t: a Tensor, a list of Tensors, or a UOp (a UOp is wrapped in a Tensor before linearizing).
    allowed: the exact kernel count expected.
    to_prerealize: optional tensors realized up front (with DEBUG=0 and TRACK_MATCH_STATS=0)
      before `t` is linearized.
    filter_sink: when true, only calls whose first source op is Ops.SINK are counted;
      when false, every call in the linear is counted.

  Returns:
    The `(linear, var_vals)` pair produced by `linear_with_vars`.

  Raises:
    KernelCountException: if the counted kernels differ from `allowed`.
    AssertionError: if `t` is none of the accepted types.
  """
  if to_prerealize:
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      Tensor.realize(*to_prerealize)

  # pick the linearization entry point based on what was passed in
  if isinstance(t, Tensor):
    linear, var_vals = t.linear_with_vars()
  elif isinstance(t, list) and isinstance(t[0], Tensor):
    linear, var_vals = Tensor.linear_with_vars(*t)
  else:
    assert isinstance(t, UOp), f"can't schedule {t}"
    linear, var_vals = Tensor(t).linear_with_vars()

  # a call whose device is a tuple counts once per device entry, otherwise once
  kernel_cnt = 0
  for call in linear.src:
    if call.src[0].op is Ops.SINK or not filter_sink:
      kernel_cnt += len(call.device) if isinstance(call.device, tuple) else 1

  if kernel_cnt == allowed:
    # test compiling the linear
    compile_linear(linear)
    return linear, var_vals

  print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}")
  if DEBUG >= 3:
    # dump every call (not just the counted ones) to help diagnose the mismatch
    for idx, call in enumerate(linear.src, start=1):
      print("kernel", idx)
      print(call.src[0])
  raise KernelCountException(f"{kernel_cnt} != {allowed}")
| 30 | |
def _realize_weights(m):
  """Call `realize()` on every parameter that `nn.state.get_parameters(m)` yields."""
  for param in nn.state.get_parameters(m):
    param.realize()
no test coverage detected
searching dependent graphs…