MCPcopy
hub / github.com/tinygrad/tinygrad / call_gradient

Function call_gradient

tinygrad/gradient.py:22–46  ·  view source on GitHub ↗
(ctx:UOp, k:UOp, needed:set[int])

Source from the content-addressed store, hash-verified

20 return body.substitute({p: p.replace(arg=j) for j,(_, p) in enumerate(used)}, walk=True), tuple(all_args[i] for i,_ in used)
21
def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
  """Produce the gradients for the sources of the call `k`.

  `ctx` is the incoming gradient, `k.src[0]` is the function body and `k.src[1:]` are the call arguments.

  Two paths:
    * `k.arg.grad_fxn` is set: it is invoked and its (tuple) result is returned behind a leading None.
      When `ctx` is a TUPLE, NOOP entries are dropped first; more than one surviving gradient is passed
      as `grad_fxn(*grads, call=k)`, otherwise `grad_fxn(grad, k)` is used.
    * otherwise: the body (which must be a TUPLE) is differentiated with `compute_gradient` and the
      gradients selected by `needed` are packed into one backward call.

  Returns a tuple whose first element is None (presumably the slot for the body, `k.src[0]` -- confirm with
  the caller). On the second path it is followed by one entry per call argument: a `gettuple` on the backward
  call when that argument index is in `needed` and has a gradient, else None.
  """
  fxn, args = k.src[0], k.src[1:]

  # a user supplied gradient function replaces automatic differentiation of the body
  custom_grad = k.arg.grad_fxn
  if custom_grad is not None:
    if ctx.op is not Ops.TUPLE: return (None,) + custom_grad(ctx, k)
    live = [g for g in ctx.src if g.op is not Ops.NOOP]
    if len(live) > 1: return (None,) + custom_grad(*live, call=k)
    return (None,) + custom_grad(live[0], k)

  assert fxn.op is Ops.TUPLE, f"expected TUPLE body for gradient, got {fxn.op}"

  # index the PARAMs of the body by their arg (toposort is asked not to enter calls)
  params = {}
  for node in fxn.toposort(enter_calls=False):
    if node.op == Ops.PARAM: params[node.arg] = node

  # incoming gradients become params numbered after the call arguments, NOOP gradients stay NOOP
  grad_args, n_args = ctx.src, len(args)
  seeds = tuple(UOp(Ops.NOOP) if g.op is Ops.NOOP else g.param_like(n_args+pos) for pos,g in enumerate(grad_args))
  root_grad = UOp(Ops.TUPLE, src=seeds)
  grads = compute_gradient(fxn, root_grad, set(params.values()))

  # for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
  if k.arg.precompile:
    first_out = n_args + len(grad_args)
    fwd_subs = {out: out.param_like(first_out+pos) for pos,out in enumerate(fxn.src)}
    fwd_outs = tuple(k.gettuple(pos) for pos in range(len(fxn.src)))
  else:
    fwd_subs, fwd_outs = {}, ()

  # collect needed gradient bodies and remember which slot of the backward tuple each arg index got
  grad_bodies, slot_of = [], {}
  for i in needed:
    p = params.get(i)
    if p is None or p not in grads: continue
    slot_of[i] = len(grad_bodies)
    grad_bodies.append((i, grads[p]))

  # compact unused params, create a single backward CALL
  bwd_body = UOp.maketuple(*(body for _, body in grad_bodies))
  bwd_body = bwd_body.substitute(fwd_subs, walk=True)
  bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
  # TODO: is this okay here?
  from tinygrad.function import pm_transform_unique_const
  bwd_body = graph_rewrite(bwd_body, pm_transform_unique_const, ctx=(None, itertools.count(0)))
  bwd_name = (k.arg.name or "")+"_backward"
  bwd_call = bwd_body.call(*compact_args, name=bwd_name, precompile=k.arg.precompile_backward)

  per_arg = []
  for i in range(n_args):
    per_arg.append(bwd_call.gettuple(slot_of[i]) if i in slot_of else None)
  return (None,) + tuple(per_arg)
47
48# ctx is grad_output
49pm_gradient = PatternMatcher([

Callers 1

compute_gradientFunction · 0.85

Calls 13

UOpClass · 0.90
graph_rewriteFunction · 0.90
compute_gradientFunction · 0.85
_compact_paramsFunction · 0.85
toposortMethod · 0.80
param_likeMethod · 0.80
gettupleMethod · 0.80
substituteMethod · 0.80
maketupleMethod · 0.80
grad_fxnMethod · 0.45
getMethod · 0.45
countMethod · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…