(ctx:UOp, ret:UOp, op:Ops)
| 5 | from tinygrad.dtype import sum_acc_dtype |
| 6 | |
def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
  """Backward rule for a reduction `ret` whose single input is `ret.src[0]`.

  ctx: gradient flowing into `ret` (presumably shaped like `ret` -- confirm at call site).
  ret: the reduction UOp being differentiated.
  op:  the reduction's combining op; ADD, MAX and MUL are handled.

  Returns a 1-tuple holding the gradient w.r.t. `ret.src[0]`, or None for any other `op`.
  """
  def broadcast_to_input(x):
    # Pad x's shape with trailing 1s up to the input's rank, then expand to the input's shape.
    src_shape = ret.src[0].shape
    return x.reshape(x.shape + (1,)*(len(src_shape)-len(x.shape))).expand(src_shape)

  if op == Ops.ADD:
    # Each input element contributes with weight 1, so the upstream gradient is simply broadcast back.
    return (broadcast_to_input(ctx),)

  if op == Ops.MAX:
    assert ret.op is Ops.REDUCE, "only works on REDUCE"
    # 1 where the input equals the reduced value, 0 elsewhere (cast to the gradient's dtype).
    is_max = ret.src[0].eq(broadcast_to_input(ret)).cast(ctx.dtype)
    # Re-reduce the mask with ADD over ret.arg[1] (presumably the reduced axes) to count ties,
    # so the gradient is split evenly between equal maxima.
    n_ties = broadcast_to_input(is_max._rop(Ops.ADD, ret.arg[1]))
    return ((is_max / n_ties) * broadcast_to_input(ctx),)

  if op == Ops.MUL:
    # Product rule: d(prod)/dx_i = prod / x_i. This divides by the input, so zero-valued
    # inputs will presumably produce non-finite gradients.
    return (broadcast_to_input(ctx * ret) / ret.src[0],)

  # Unsupported reduction op: no gradient rule.
  return None
| 16 | |
| 17 | def _compact_params(body:UOp, all_args:tuple[UOp, ...]) -> tuple[UOp, tuple[UOp, ...]]: |
| 18 | """Remove unused PARAMs from body and return compacted (body, args).""" |
no test coverage detected
searching dependent graphs…