(out:UOp)
| 322 | |
| 323 | # TODO: how do we remove this dumb kernel and use Tensor.zeros? |
def _zero_kernel(out:UOp) -> UOp:
  """Build the "zero" kernel: store 0 through a range-indexed view of `out`.

  The range is sized by `out.numel()` and indexes the flattened `out`; the
  store is closed with `.end(...)` on that same range and wrapped in a sink
  carrying `KernelInfo(name="zero")`.
  """
  # NOTE(review): the second argument to UOp.range (0) is presumably the range/axis id - confirm.
  idx = UOp.range(out.numel(), 0)
  zero_store = out.flatten()[idx].store(0)
  return zero_store.end(idx).sink(arg=KernelInfo(name="zero"))
# Apply the zero kernel to the buffer and keep the first element of what custom_kernel returns.
zero_kernel_outputs = grad_weight_uop.custom_kernel(fxn=_zero_kernel)
grad_weight_uop = zero_kernel_outputs[0]
| 328 | |
| 329 | # TODO: do we have a universal helper for this? |