(ctx:UOp, k:UOp, needed:set[int])
| 20 | return body.substitute({p: p.replace(arg=j) for j,(_, p) in enumerate(used)}, walk=True), tuple(all_args[i] for i,_ in used) |
| 21 | |
def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
  """Build the gradients for the sources of the call `k`.

  `ctx` is the incoming output gradient, `k.src[0]` is the called body and `k.src[1:]` are the call arguments.
  `needed` holds the argument indices whose gradient is requested.

  The result always starts with None (the slot of `k.src[0]`). When `k.arg.grad_fxn` is set, the rest is whatever
  that function returns; otherwise it is one entry per call argument: a gradient UOp, or None if none was produced.
  """
  fxn, args = k.src[0], k.src[1:]

  # a grad_fxn attached to the call takes priority over differentiating the body
  custom_grad = k.arg.grad_fxn
  if custom_grad is not None:
    if ctx.op is not Ops.TUPLE: return (None,) + custom_grad(ctx, k)
    # NOOP entries of the gradient tuple are not passed on to grad_fxn
    live_grads = [g for g in ctx.src if g.op is not Ops.NOOP]
    if len(live_grads) > 1: return (None,) + custom_grad(*live_grads, call=k)
    return (None,) + custom_grad(live_grads[0], k)

  assert fxn.op is Ops.TUPLE, f"expected TUPLE body for gradient, got {fxn.op}"
  n_args = len(args)
  param_by_slot = {u.arg:u for u in fxn.toposort(enter_calls=False) if u.op == Ops.PARAM}
  out_grads = ctx.src

  # seed gradient: NOOP stays NOOP, every other output gradient is replaced by param_like, numbered after the call args
  seed_srcs = []
  for slot, g in enumerate(out_grads, start=n_args):
    seed_srcs.append(UOp(Ops.NOOP) if g.op is Ops.NOOP else g.param_like(slot))
  root_grad = UOp(Ops.TUPLE, src=tuple(seed_srcs))
  grads = compute_gradient(fxn, root_grad, set(param_by_slot.values()))

  # for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
  if k.arg.precompile:
    fwd_subs = {out: out.param_like(slot) for slot, out in enumerate(fxn.src, start=n_args+len(out_grads))}
    fwd_outs = tuple(k.gettuple(i) for i in range(len(fxn.src)))
  else:
    fwd_subs, fwd_outs = {}, ()

  # collect needed gradient bodies, compact unused params, create a single backward CALL
  wanted = []
  for i in needed:
    p = param_by_slot.get(i)
    if p is not None and p in grads: wanted.append((i, grads[p]))
  bwd_body = UOp.maketuple(*[gb for _, gb in wanted]).substitute(fwd_subs, walk=True)
  bwd_body, compact_args = _compact_params(bwd_body, (*args, *out_grads, *fwd_outs))
  # TODO: is this okay here?
  from tinygrad.function import pm_transform_unique_const
  bwd_body = graph_rewrite(bwd_body, pm_transform_unique_const, ctx=(None, itertools.count(0)))
  backward_name = (k.arg.name or "")+"_backward"
  bwd_call = bwd_body.call(*compact_args, name=backward_name, precompile=k.arg.precompile_backward)

  # map each argument index to its position inside the backward call's tuple
  tuple_pos = {i: pos for pos, (i, _) in enumerate(wanted)}
  arg_grads = tuple(bwd_call.gettuple(tuple_pos[i]) if i in tuple_pos else None for i in range(n_args))
  return (None,) + arg_grads
| 47 | |
| 48 | # ctx is grad_output |
| 49 | pm_gradient = PatternMatcher([ |
no test coverage detected
searching dependent graphs…