| 133 | |
# IRModule that mixes a TIR kernel (matmul_tir) with a Python-level function
# (forward, marked @I.pyfunc). The matmul runs through TIR while forward can
# still print intermediate tensors from ordinary Python code.
@I.ir_module
class DebugModule(BasePyModule):
    # matmul_tir: C = A @ B on float32 buffers.
    #   A is (n, 4), B is (4, 3), and the output C is (n, 3).
    #   n is a symbolic size bound by A's (and C's) first dimension.
    @T.prim_func(s_tir=True)
    def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle):
        n = T.int32()
        A = T.match_buffer(var_A, (n, 4), "float32")
        B = T.match_buffer(var_B, (4, 3), "float32")
        C = T.match_buffer(var_C, (n, 3), "float32")
        for i, j, k in T.grid(n, 3, 4):
            with T.sblock("matmul"):
                # i and j are spatial ("S") axes; k is the reduction ("R") axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Zero each output element before the reduction over k accumulates into it.
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    # forward: matmul via matmul_tir, then softmax over the last dimension.
    # Prints debug information at each stage.
    #   x       -- passed to matmul_tir as A, so it must be float32 with shape
    #              (n, 4); presumably a torch tensor (confirm with callers).
    #   weights -- passed as B, so it must be float32 with shape (4, 3).
    # Returns the softmax of x @ weights along dim -1, shape (x.shape[0], 3).
    # Side effect: prints three "[DEBUG]" lines to stdout.
    @I.pyfunc
    def forward(self, x, weights):
        # Inspect input
        print(f" [DEBUG] input shape: {x.shape}, mean: {x.mean():.4f}")

        # Run TIR matmul. The conversion helpers are not defined in this class,
        # presumably inherited from BasePyModule.
        x_tvm = self._convert_pytorch_to_tvm(x)
        w_tvm = self._convert_pytorch_to_tvm(weights)
        out = self.call_tir(
            self.matmul_tir,
            [x_tvm, w_tvm],
            # Declared output shape/dtype; must agree with C's match_buffer in matmul_tir.
            out_sinfo=R.Tensor((x.shape[0], 3), "float32"),
        )
        logits = self._convert_tvm_to_pytorch(out)

        # Inspect the intermediate logits from plain Python. A compiled-only
        # workflow would not expose this value.
        print(
            f" [DEBUG] logits shape: {logits.shape}, "
            f"min: {logits.min():.4f}, max: {logits.max():.4f}"
        )

        # F is presumably torch.nn.functional (import not visible here).
        result = F.softmax(logits, dim=-1)

        # Verify output: each row's probabilities should sum to ~1.
        print(f" [DEBUG] probs sum: {result.sum(dim=-1)}")
        return result
# Instantiate DebugModule on CPU device 0 (tvm.cpu's default device id).
mod = DebugModule(device=tvm.cpu())
| 177 |
no outgoing calls
no test coverage detected
searching dependent graphs…