(C:UOp, D:UOp, A:UOp, B:UOp)
| 26 | return C[i].store(A[i]+B[i]).end(i).sink(arg=KernelInfo(name=f"custom_add_kernel_{C.numel()}")).simplify() |
| 27 | |
def custom_elementwise_addmul_kernel(C:UOp, D:UOp, A:UOp, B:UOp) -> UOp:
  """Build one custom kernel graph that writes C = A + B and D = A * B elementwise.

  All four buffers are flattened, then indexed with a single shared range whose
  length is C's element count. Both stores are grouped, ended on that range, and
  wrapped in one sink tagged with a KernelInfo name that embeds the element count.
  The graph is simplified before being returned.

  Raises AssertionError if C and D do not have the same number of elements.
  NOTE(review): A and B are not size-checked here; they are assumed to hold at
  least C.numel() elements — confirm against callers.
  """
  # flatten in the same order as the parameters: C, D, A, B
  C, D, A, B = (buf.flatten() for buf in (C, D, A, B))
  # both outputs are written through the same index, so their sizes must agree
  assert C.numel() == D.numel()
  count = C.numel()
  # second argument presumably identifies the range/axis — TODO confirm semantics
  idx = UOp.range(count, 0)
  write_sum = C[idx].store(A[idx] + B[idx])
  write_prod = D[idx].store(A[idx] * B[idx])
  kernel_name = f"custom_addmul_kernel_{count}"
  grouped = UOp.group(write_sum, write_prod).end(idx)
  return grouped.sink(arg=KernelInfo(name=kernel_name)).simplify()
| 35 | |
| 36 | def custom_gemm(C:UOp, A:UOp, B:UOp) -> UOp: |
| 37 | assert A.shape[1] == B.shape[0] |
nothing calls this directly
no test coverage detected
searching dependent graphs…