| 211 | |
# IRModule combining one TIR kernel (`matmul_tir`) with a Python-side
# `forward` method that chains three stages: TIR matmul -> packed bias add ->
# ReLU.
@I.ir_module
class PipelineModule(BasePyModule):
    # Fixed-shape float32 matmul: C[2, 3] = A[2, 4] @ B[4, 3].
    # NOTE: this function is TVMScript; it is kept token-for-token because the
    # identifiers and the block name are part of what the decorator parses.
    @T.prim_func(s_tir=True)
    def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle):
        A = T.match_buffer(var_A, (2, 4), "float32")
        B = T.match_buffer(var_B, (4, 3), "float32")
        C = T.match_buffer(var_C, (2, 3), "float32")
        for i, j, k in T.grid(2, 3, 4):
            with T.sblock("matmul"):
                # "SSR": vi and vj are spatial axes, vk is the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    # Zero the accumulator before the reduction over vk.
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    # Runs the three-stage pipeline on `x`, `weights` and `bias` and returns
    # the result of `F.relu`.
    # NOTE(review): the (2, 3) output shape is hard-coded in both `out_sinfo`
    # arguments, so inputs are presumably expected to be (2, 4) and (4, 3) to
    # match `matmul_tir` -- confirm against callers.
    @I.pyfunc
    def forward(self, x, weights, bias):
        # Stage 1 -- TIR matmul. Both operands go through
        # `_convert_pytorch_to_tvm` before being handed to the kernel.
        tir_inputs = [
            self._convert_pytorch_to_tvm(x),
            self._convert_pytorch_to_tvm(weights),
        ]
        matmul_out = self.call_tir(
            self.matmul_tir,
            tir_inputs,
            out_sinfo=R.Tensor((2, 3), "float32"),
        )

        # Stage 2 -- packed function for the bias add (simulating an external
        # library). The matmul result is converted back with
        # `_convert_tvm_to_pytorch` first; "my_bias_add" is looked up by name
        # and is not defined in this block.
        packed_inputs = [self._convert_tvm_to_pytorch(matmul_out), bias]
        biased = self.call_dps_packed(
            "my_bias_add",
            packed_inputs,
            out_sinfo=R.Tensor((2, 3), "float32"),
        )

        # Stage 3 -- Python/PyTorch activation.
        activated = F.relu(biased)
        return activated
| 247 | |
# Instantiate the pipeline module on CPU device 0.
mod = PipelineModule(
    device=tvm.cpu(0),
)
| 249 |
no outgoing calls
no test coverage detected
searching dependent graphs…