| 33 | class InputModule: |
| 34 | @T.prim_func(s_tir=True)  # TIR PrimFunc; s_tir=True presumably opts into the schedulable-TIR dialect that provides T.sblock -- confirm against the TVM version in use |
| 35 | def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:  # Computes C = A @ B on 16x16 buffers; x, y, z are the raw handles bound to A, B, C below |
| 36 | T.func_attr({"global_symbol": "tir_matmul"})  # exported symbol name, kept identical to the Python function name |
| 37 | A = T.match_buffer(x, (16, 16))  # left operand; no dtype given, so match_buffer's default (float32) applies, consistent with T.float32(0) below |
| 38 | B = T.match_buffer(y, (16, 16))  # right operand, same shape and default dtype |
| 39 | C = T.match_buffer(z, (16, 16))  # output buffer, also used as the reduction accumulator |
| 40 | for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):  # row axis pre-split as i0*4+i1 and reduction axis as k0*4+k1, ordered i0, j, k0, i1, k1; NOTE(review): looks deliberately pre-tiled as a transform-test input -- do not normalize |
| 41 | with T.sblock("matmul"):  # the single compute block, named "matmul" |
| 42 | vi = T.axis.S(16, i0 * 4 + i1)  # spatial row index, recombined from the split loops i0 and i1 |
| 43 | vj = T.axis.S(16, j)  # spatial column index, mapped directly from j |
| 44 | vk = T.axis.R(16, k0 * 4 + k1)  # reduction index, recombined from the split loops k0 and k1 |
| 45 | with T.init():  # initialization step of the reduction, executed before accumulation for each (vi, vj) |
| 46 | C[vi, vj] = T.float32(0)  # zero the accumulator |
| 47 | C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]  # accumulate the row-by-column dot product over vk |
| 48 | |
| 49 | @R.function |
| 50 | def main(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")) -> R.Tensor( |