Remove unused PARAMs from body and return compacted (body, args).
(body:UOp, all_args:tuple[UOp, ...])
| 15 | if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],) |
| 16 | |
def _compact_params(body:UOp, all_args:tuple[UOp, ...]) -> tuple[UOp, tuple[UOp, ...]]:
  """Drop PARAM slots that `body` never references and renumber the survivors densely.

  Returns the rewritten body together with the matching subset of `all_args`, ordered by
  the original PARAM index, so that renumbered slot j corresponds to the j-th returned arg.
  """
  # map each referenced PARAM index to its node; if several nodes share an index,
  # the one visited last in toposort order is kept (plain dict overwrite)
  by_index = {}
  for u in body.toposort():
    if u.op is Ops.PARAM: by_index[u.arg] = u
  kept = sorted(by_index)
  # old index -> dense new index; substitute applies all renames at once, so a new index
  # colliding with another node's old index is not a problem
  remap = {by_index[old]: by_index[old].replace(arg=new) for new, old in enumerate(kept)}
  return body.substitute(remap, walk=True), tuple(all_args[old] for old in kept)
| 21 | |
| 22 | def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]: |
| 23 | fxn, args = k.src[0], k.src[1:] |
no test coverage detected
searching dependent graphs…