(B:UOp, A:UOp)
| 15 | return C[i, j].store((i.eq(j)).cast(C.dtype.base)).end(i, j).sink(arg=KernelInfo(name=f"custom_eye_{C.numel()}")) |
| 16 | |
| 17 | def custom_add_one_kernel(B:UOp, A:UOp) -> UOp: |
| 18 | A,B = A.flatten(), B.flatten() |
| 19 | assert B.numel() == A.numel() |
| 20 | i = UOp.range(A.numel(), 0) |
| 21 | return B[i].store(A[i] + 1).end(i).sink(arg=KernelInfo(name=f"add_one_{A.numel()}")) |
| 22 | |
| 23 | def custom_elementwise_add_kernel(C:UOp, A:UOp, B:UOp) -> UOp: |
| 24 | C,A,B = C.flatten(), A.flatten(), B.flatten() |