hub / github.com/tinygrad/tinygrad / reduce_gradient

Function reduce_gradient

tinygrad/gradient.py:7–15
(ctx:UOp, ret:UOp, op:Ops)

Source from the content-addressed store, hash-verified

from tinygrad.dtype import sum_acc_dtype

def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
  def broadcast_to_input(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
  if op == Ops.ADD: return (broadcast_to_input(ctx),)
  if op == Ops.MAX:
    assert ret.op is Ops.REDUCE, "only works on REDUCE"
    mask = ret.src[0].eq(broadcast_to_input(ret)).cast(ctx.dtype)
    count = mask._rop(Ops.ADD, ret.arg[1])
    return ((mask/broadcast_to_input(count)) * broadcast_to_input(ctx),)
  if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)

def _compact_params(body:UOp, all_args:tuple[UOp, ...]) -> tuple[UOp, tuple[UOp, ...]]:
  """Remove unused PARAMs from body and return compacted (body, args)."""
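
The three branches of reduce_gradient (ADD, MAX, MUL) can be illustrated with a small NumPy sketch. This is not tinygrad code: the helper name reduce_gradient_np, the string op labels, and the explicit axis parameter are assumptions made for the example. It also assumes the upstream gradient keeps the reduced axes as size 1, so NumPy broadcasting stands in for broadcast_to_input.

```python
import numpy as np

def reduce_gradient_np(upstream, x, axis, op):
    """Hypothetical NumPy analogue of the ADD / MAX / MUL reduce gradient rules.

    x        -- input of the reduction
    upstream -- gradient w.r.t. the reduced output (reduced axes kept as size 1)
    """
    if op == "ADD":
        # Sum: every input element receives the upstream gradient unchanged.
        return np.broadcast_to(upstream, x.shape).copy()
    if op == "MAX":
        # Max: only elements equal to the maximum receive gradient,
        # shared evenly among ties (mask / count).
        out = x.max(axis=axis, keepdims=True)
        mask = (x == out).astype(upstream.dtype)
        count = mask.sum(axis=axis, keepdims=True)
        return mask / count * upstream
    if op == "MUL":
        # Product: d(prod)/dx_i = prod / x_i, scaled by the upstream gradient.
        out = x.prod(axis=axis, keepdims=True)
        return upstream * out / x
    raise ValueError(f"unsupported op: {op}")

# Example inputs: reduce over axis 1 with an upstream gradient of ones.
x = np.array([[1.0, 3.0, 3.0], [2.0, 0.0, 1.0]])
g_add = reduce_gradient_np(np.ones((2, 1)), x, 1, "ADD")
g_max = reduce_gradient_np(np.ones((2, 1)), x, 1, "MAX")

# A separate input without zeros for the product rule.
x_mul = np.array([[1.0, 2.0, 4.0]])
g_mul = reduce_gradient_np(np.ones((1, 1)), x_mul, 1, "MUL")
```

In the first row of x the maximum 3.0 appears twice, so each tied element receives half of the gradient, mirroring mask/count in the source. The MUL branch divides by the input, so in this sketch a zero entry in the input would produce inf or nan.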

Callers 1

gradient.py  File · 0.85

Calls 4

broadcast_to_input  Function · 0.85
eq  Method · 0.80
cast  Method · 0.45
_rop  Method · 0.45

Tested by

no test coverage detected
