(big_sink:UOp)
| 120 | |
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[0].src))}")
def create_linear_with_vars(big_sink:UOp) -> tuple[UOp, dict[str, int]]:
  """
  Lower `big_sink` into a LINEAR schedule and gather the values of the symbolic variables that schedule uses.

  Returns `(linear, var_vals)`. `var_vals` maps each variable name referenced by a scheduled item
  (`item.src[0].variables()`) to the value bound to it by an Ops.BIND found in `big_sink.src[1:]`.
  While a JIT capture is active, the schedule is handed to `capturing[0].add_linear` and an empty LINEAR is returned.
  Otherwise, the result of `memory_plan_rewrite` is returned.

  Raises RuntimeError when one referenced variable is bound to two different values.
  """
  # the sources of big_sink are the Tensors to schedule
  call = graph_rewrite(big_sink, pm_schedule, name="schedule to linear", enter_calls=True)
  # recursively resolve the linear call; buffers are allocated during this rewrite
  linear = graph_rewrite(call, pm_resolve_linear_call, name="resolve linear call")

  # names of every variable referenced anywhere in the schedule
  referenced: set[str] = set()
  for item in linear.src: referenced.update(v.expr for v in item.src[0].variables())

  # pick up bound values, but only for variables the schedule actually references
  var_vals: dict[str, int] = {}
  for bind in (u for u in big_sink.src[1:] if u.op is Ops.BIND):
    name = bind.src[0].expr
    if name not in referenced: continue
    value = bind.src[1].arg
    # two binds of the same variable must agree on the value
    if var_vals.get(name, value) != value: raise RuntimeError(f"bind mismatch on {name}, {var_vals[name]} != {value}")
    var_vals[name] = value

  if not (len(capturing) and CAPTURING):
    # normal path: the BUFFER sources of a top-level CALL are passed to the memory planner as held buffers
    held_bufs: set[UOp] = set()
    if call.op is Ops.CALL: held_bufs = {s for s in call.src[1:] if s.op is Ops.BUFFER}
    return memory_plan_rewrite(linear, held_bufs), var_vals

  # a JIT is capturing: record the schedule instead of executing it, and return an empty LINEAR
  capturing[0].add_linear(linear, var_vals)
  return UOp(Ops.LINEAR, src=()), var_vals
no test coverage detected
searching dependent graphs…