TIR function for matrix multiplication.
(
var_A: T.handle,
var_B: T.handle,
var_C: T.handle,
)
| 64 | |
| 65 | @T.prim_func(s_tir=True) |
| 66 | def matmul( |
| 67 | var_A: T.handle, |
| 68 | var_B: T.handle, |
| 69 | var_C: T.handle, |
| 70 | ): |
| 71 | """TIR function for matrix multiplication: writes C = A @ B, where A is (n, 16), B is (16, 20) and C is (n, 20), all float32, and n is a T.int32() variable used as the row extent of A and C; C is zero-initialised in T.init() before accumulating over the 16-wide reduction axis.""" |
| 72 | n = T.int32() |
| 73 | A = T.match_buffer(var_A, (n, 16), "float32") |
| 74 | B = T.match_buffer(var_B, (16, 20), "float32") |
| 75 | C = T.match_buffer(var_C, (n, 20), "float32") |
| 76 | |
| 77 | for i, j, k in T.grid(n, 20, 16): |
| 78 | with T.sblock("block"): |
| 79 | vi, vj, vk = T.axis.remap("SSR", [i, j, k]) |
| 80 | with T.init(): |
| 81 | C[vi, vj] = T.float32(0) |
| 82 | C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] |
| 83 | |
| 84 | @I.pyfunc |
| 85 | def my_identity_func(self, x: torch.Tensor) -> torch.Tensor: |
no test coverage detected