(root:UOp, root_grad:UOp, targets:set[UOp])
| 93 | return [node for node in in_target_path if node.op is not Ops.DETACH and in_target_path[node]], in_target_path |
| 94 | |
def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
  """
  Propagate `root_grad` backward from `root` and return every gradient recorded along the way.

  The nodes returned by `_deepwalk(root, targets)` are visited in reverse order. A node is skipped when no gradient has
  been recorded for it yet, or when the recorded gradient is an `Ops.NOOP` placeholder. Gradients arriving at the same
  UOp from several consumers are summed.

  Returns the dict mapping each reached UOp to its gradient UOp (`root` maps to `root_grad`).

  Raises RuntimeError when no gradient rule produces a result for a visited node, and AssertionError when a rule
  returns a different number of gradients than the node has sources.
  """
  walk, in_target_path = _deepwalk(root, targets)
  grads: dict[UOp, UOp] = {root: root_grad}
  for node in reversed(walk):
    if node not in grads or grads[node].op is Ops.NOOP: continue
    out_grad = grads[node]

    # GETTUPLE: fold this output's gradient into a TUPLE stored on the FUNCTION; the FUNCTION itself is handled
    # later, when the walk reaches it
    if node.op is Ops.GETTUPLE:
      fxn = node.src[0]
      assert fxn.op is Ops.FUNCTION and fxn.src[0].op is Ops.TUPLE
      n_outputs = len(fxn.src[0].src)
      prev = grads[fxn].src if fxn in grads else tuple(UOp(Ops.NOOP) for _ in range(n_outputs))
      slots: list[UOp] = []
      for i in range(n_outputs):
        if i != node.arg: slots.append(prev[i])                # other outputs keep what they had
        elif prev[i].op is Ops.NOOP: slots.append(out_grad)    # first gradient seen for this output
        else: slots.append(prev[i] + out_grad)                 # accumulate onto an existing gradient
      grads[fxn] = UOp.maketuple(*slots)
      continue

    # FUNCTION/CALL: hand over the indices of the params that matter so backward only computes required gradients
    # (FUNCTION uses implicit TUPLE gradient or grad_fxn; CALL requires an explicit grad_fxn)
    if node.op in {Ops.FUNCTION, Ops.CALL}:
      needed = {i for i, arg in enumerate(node.src[1:]) if arg in targets or in_target_path.get(arg, False)}
      lgrads:tuple[UOp|None, ...]|None = call_gradient(out_grad, node, needed)
    else:
      lgrads = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(node, ctx=out_grad))
    if lgrads is None: raise RuntimeError(f"failed to compute gradient for {node.op}\n\nin {str(node)[0:1000]}...")
    assert len(lgrads) == len(node.src), f"got {len(lgrads)} gradient, expected {len(node.src)}"

    for src, new_grad in zip(node.src, lgrads):
      if new_grad is None: continue
      if src not in grads or grads[src].op is Ops.NOOP:
        # nothing real recorded yet: the new gradient simply takes the slot
        grads[src] = new_grad
      elif new_grad.op is Ops.TUPLE and grads[src].op is Ops.TUPLE:
        # merge two TUPLE gradients element-wise, letting NOOP entries defer to the other side
        grads[src] = UOp.maketuple(*(n if p.op is Ops.NOOP else p if n.op is Ops.NOOP else p + n
                                     for p, n in zip(grads[src].src, new_grad.src)))
      else:
        grads[src] = grads[src] + new_grad

      forward_metadata = all_metadata.get(node, ())
      if len(forward_metadata) == 0: continue
      backward_metadata = tuple(dataclasses.replace(m, backward=True) for m in forward_metadata)
      # we add the backward metadata to everything new in the graph: the forward node, its sources and the
      # incoming gradient gate the traversal
      skip = (node, *node.src, out_grad)
      for bw_uop in new_grad.toposort(lambda u: u not in skip):
        all_metadata[bw_uop] = all_metadata.get(bw_uop, ())+backward_metadata
  return grads
searching dependent graphs…