(call:UOp)
| 119 | return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis) |
| 120 | |
def rewrite_into_function(call:UOp):
  """Resolve MULTI inside a call's body and strip MULTI from its arguments.

  Returns None (no rewrite) when `call.arg.precompile` is truthy. Otherwise:
  the body `call.src[0]` is rewritten with `multi_pm`, and every argument in
  `call.src[1:]` that is a MULTI is replaced by its single source.
  The rewritten body must be a TUPLE. If none of its elements is MULTI, the call
  is returned with the new body and arguments. If any element is MULTI, a
  per-shard call is built whose body has MULTI stripped from each element, and a
  TUPLE of that call's GETTUPLEs is returned, each re-wrapped with `.multi(axis)`
  (using the axis of the matching body element) when that element was MULTI.
  """
  if call.arg.precompile: return None
  body = graph_rewrite(call.src[0], multi_pm, name="subcall")
  # unwrap a single MULTI layer, leave anything else as-is
  unwrap = lambda u: u.src[0] if u.op is Ops.MULTI else u
  args = tuple(map(unwrap, call.src[1:]))
  assert body.op is Ops.TUPLE
  outs = body.src
  # fast path: no sharded outputs, just swap in the rewritten body and plain args
  if not any(o.op is Ops.MULTI for o in outs): return call.replace(src=(body,)+args)
  # per-shard call: body elements with their MULTI wrappers removed
  shard_call = call.replace(src=(UOp.maketuple(*map(unwrap, outs)),)+args)
  results = []
  for i, o in enumerate(outs):
    elem = shard_call.gettuple(i)
    # restore sharding only on outputs that were sharded in the rewritten body
    results.append(elem.multi(o.axis) if o.op is Ops.MULTI else elem)
  return UOp.maketuple(*results)
| 131 | |
| 132 | def param_to_multi(p:UOp): |
| 133 | if p.axis is None: return None |
nothing calls this directly
no test coverage detected
searching dependent graphs…