(x)
| 6 | |
| 7 | def reduce_gradient(ctx:UOp, ret:UOp, op:Ops): |
| 8 | def broadcast_to_input(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape) |
| 9 | if op == Ops.ADD: return (broadcast_to_input(ctx),) |
| 10 | if op == Ops.MAX: |
| 11 | assert ret.op is Ops.REDUCE, "only works on REDUCE" |
no test coverage detected
searching dependent graphs…