| 34 | return UOp.group(store_c, store_d).end(i).sink(arg=KernelInfo(name=f"custom_addmul_kernel_{C.numel()}")).simplify() |
| 35 | |
def custom_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
  """Build a custom kernel that writes the matrix product A @ B into C.

  Each output element C[i, j] is first set to 0.0, then A[i, k] * B[k, j]
  is accumulated into it over a REDUCE range k, so that
  C[i, j] = sum_k A[i, k] * B[k, j].

  Args:
    C: output buffer, indexed as C[i, j]; shape (M, N).
    A: left operand, indexed as A[i, k]; shape (M, K).
    B: right operand, indexed as B[k, j]; shape (K, N).

  Returns:
    The SINK UOp of the kernel, tagged with a KernelInfo named
    "custom_gemm_{M}_{N}_{K}" and an empty opts_to_apply.

  Raises:
    AssertionError: if A.shape[1] != B.shape[0], or if C's leading two dims
      are not (A.shape[0], B.shape[1]).
  """
  # Capture the problem sizes from the caller's buffers before C is rebound below.
  M, N, K = C.shape[0], C.shape[1], A.shape[1]
  assert K == B.shape[0], f"custom_gemm: inner dims differ, A.shape={A.shape} B.shape={B.shape}"
  # i and j take their extents from C but also index A[i, k] and B[k, j]; if the
  # shapes disagree the kernel would index A/B beyond their extents.
  assert A.shape[0] == M and B.shape[1] == N, \
    f"custom_gemm: C.shape={C.shape} does not match A.shape={A.shape} @ B.shape={B.shape}"
  i, j = UOp.range(M, 0), UOp.range(N, 1)
  k = UOp.range(K, 2, axis_type=AxisType.REDUCE)
  # Zero the output element before accumulating into it.
  acc = C[i, j].set(0.0)
  # Accumulate A[i, k] * B[k, j] over the REDUCE range k; end=k closes k around this store.
  acc = acc[i, j].set(acc.after(k)[i, j] + A[i, k] * B[k, j], end=k)
  prog = acc.end(i, j)
  # NOTE(review): opts_to_apply=() presumably stops the optimizer from applying its own opts — confirm.
  return prog.sink(arg=KernelInfo(name=f"custom_gemm_{M}_{N}_{K}", opts_to_apply=()))
| 43 | |
| 44 | def custom_sum(B:UOp, A:UOp) -> UOp: |
| 45 | i = UOp.range(A.shape[0], 0, axis_type=AxisType.REDUCE) |