MCPcopy
hub / github.com/tinygrad/tinygrad / compute_gradient

Function compute_gradient

tinygrad/gradient.py:95–131  ·  view source on GitHub ↗
(root:UOp, root_grad:UOp, targets:set[UOp])

Source from the content-addressed store, hash-verified

93 return [node for node in in_target_path if node.op is not Ops.DETACH and in_target_path[node]], in_target_path
94
def _accum_grad(prev:UOp, new:UOp) -> UOp:
  """Sum two gradient contributions, treating an Ops.NOOP placeholder as "no gradient" (the additive identity).

  Operand order is kept as `prev + new`.
  """
  if prev.op is Ops.NOOP: return new
  if new.op is Ops.NOOP: return prev
  return prev + new

def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
  """Backpropagate `root_grad` from `root` and return the accumulated gradient of every node that received one.

  Args:
    root: the UOp being differentiated.
    root_grad: the gradient seeded at `root`.
    targets: UOps whose gradients are wanted; `_deepwalk` uses them to restrict the walk to nodes on a path to a target.

  Returns:
    A dict mapping each UOp that received a gradient to that gradient. Entries may be Ops.NOOP placeholders
    (treated as "no gradient"); a FUNCTION reached through GETTUPLE holds a TUPLE with one slot per output.

  Raises:
    RuntimeError: if no gradient could be computed for a visited node (the rule / call_gradient returned None).
    AssertionError: on a malformed GETTUPLE or a gradient count that does not match the node's sources.
  """
  walk, in_target_path = _deepwalk(root, targets)
  grads: dict[UOp, UOp] = {root: root_grad}
  # NOTE(review): assumes `walk` is topologically ordered (sources first), so reversing it finishes every consumer
  # of a node before that node's own gradient is propagated -- confirm against _deepwalk
  for t0 in reversed(walk):
    # nothing flowed into this node (or only a NOOP placeholder): nothing to propagate
    if t0 not in grads or grads[t0].op is Ops.NOOP: continue
    # GETTUPLE: accumulate gradient into a TUPLE UOp on the FUNCTION, process when we hit the FUNCTION
    if t0.op is Ops.GETTUPLE:
      fxn = t0.src[0]
      assert fxn.op is Ops.FUNCTION and fxn.src[0].op is Ops.TUPLE
      n_outputs = len(fxn.src[0].src)
      # a missing or NOOP gradient on the FUNCTION means nothing is accumulated yet, so start from all-NOOP slots.
      # (UOp(Ops.NOOP) is built without sources, so indexing a NOOP placeholder's .src would raise IndexError)
      if fxn in grads and grads[fxn].op is not Ops.NOOP: prev = grads[fxn].src
      else: prev = tuple(UOp(Ops.NOOP) for _ in range(n_outputs))
      # only slot t0.arg changes; grads[t0] is non-NOOP here (guard above), so this is add-or-replace
      grads[fxn] = UOp.maketuple(*(_accum_grad(prev[i], grads[t0]) if i == t0.arg else prev[i] for i in range(n_outputs)))
      continue
    # FUNCTION/CALL: pass needed param set so backward only computes required gradients
    # (FUNCTION uses implicit TUPLE gradient or grad_fxn; CALL requires an explicit grad_fxn)
    if t0.op in {Ops.FUNCTION, Ops.CALL}:
      # indices are relative to t0.src[1:], i.e. the arguments after src[0]
      needed = {i for i, arg in enumerate(t0.src[1:]) if arg in targets or in_target_path.get(arg, False)}
      lgrads:tuple[UOp|None, ...]|None = call_gradient(grads[t0], t0, needed)
    else:
      lgrads = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
    if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
    assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
    for k,v in zip(t0.src, lgrads):
      # None: this source receives no gradient contribution from t0
      if v is None: continue
      if k in grads and grads[k].op is not Ops.NOOP:
        # two TUPLE gradients merge slot by slot, with NOOP slots acting as "no gradient"
        if v.op is Ops.TUPLE and grads[k].op is Ops.TUPLE:
          grads[k] = UOp.maketuple(*(_accum_grad(p, n) for p, n in zip(grads[k].src, v.src)))
        else: grads[k] = grads[k] + v
      else: grads[k] = v
      if len(forward_metadata:=all_metadata.get(t0, ())):
        backward_metadata = tuple(dataclasses.replace(x, backward=True) for x in forward_metadata)
        # we add the backward metadata to everything new in the graph
        for bw_uop in v.toposort(lambda x: x not in (t0, *t0.src, grads[t0])):
          all_metadata[bw_uop] = all_metadata.get(bw_uop, ())+backward_metadata
  return grads

Callers 4

gradient (Method) · 0.90
call_gradient (Function) · 0.85

Calls 9

UOp (Class) · 0.90
_deepwalk (Function) · 0.85
call_gradient (Function) · 0.85
cast (Function) · 0.85
maketuple (Method) · 0.80
toposort (Method) · 0.80
get (Method) · 0.45
rewrite (Method) · 0.45
replace (Method) · 0.45

Tested by 2

Used in the wild: real call sites across dependent graphs

searching dependent graphs…