(call:UOp, resolved:list[UOp])
def resolve_params(call:UOp, inputs:tuple[UOp, ...]) -> list[UOp]:
  """Pass every argument UOp of `call` (in the order get_call_arg_uops returns them) through _resolve with `inputs`.

  Returns a new list holding one _resolve result per argument.
  """
  resolved = []
  for arg in get_call_arg_uops(call): resolved.append(_resolve(arg, inputs))
  return resolved
| 141 | |
def unwrap_multi(call:UOp, resolved:list[UOp]) -> Iterator[tuple[list[Buffer], dict[str, int]]]:
  """Yield (buffers, var_vals) pairs to run `call` with.

  If no resolved UOp's `.buffer` is a MultiBuffer, yields that buffer list once with an empty dict.
  Otherwise yields one pair per index j of the MultiBuffers' `.bufs` (presumably one per device):
  the j-th sub-buffer of every entry, plus {'_device_num': j} when call.src[0] references a variable
  whose expr is '_device_num', else an empty dict.
  """
  bufs = [u.buffer for u in resolved]
  if all(not isinstance(b, MultiBuffer) for b in bufs):
    yield cast(list[Buffer], bufs), {}
    return
  # only bind the index into var_vals when the call's first source actually uses the `_device_num` variable
  has_dnum = any(v.expr == '_device_num' for v in call.src[0].variables())
  # NOTE(review): every entry is cast to MultiBuffer (a plain Buffer mixed in would hit `.bufs`), and zip stops at the
  # shortest `.bufs` list -- assumes all entries share the same device count; confirm against callers
  per_index = zip(*[cast(MultiBuffer, b).bufs for b in bufs])
  for idx, group in enumerate(per_index):
    yield list(group), ({'_device_num': idx} if has_dnum else {})
| 148 | |
| 149 | def exec_view(ctx:ExecContext, call:UOp, ast:UOp) -> float|None: |
| 150 | resolved = resolve_params(call, ctx.input_uops) |
searching dependent graphs…