(self, x, weights)
| 148 | |
@I.pyfunc
def forward(self, x, weights):
    """Run the TIR matmul on ``x`` and ``weights`` and return softmax probabilities.

    Debug-oriented forward pass. It prints the input statistics, the range
    of the raw logits produced by ``self.matmul_tir``, and the per-row sums
    of the resulting probabilities, so intermediate values can be inspected
    while the function runs as ordinary Python.

    Args:
        x: PyTorch input tensor. ``x.shape[0]`` is used as the leading
            dimension of the declared output (assumed to be
            (batch, features) -- TODO confirm against callers).
        weights: PyTorch weight tensor handed to ``self.matmul_tir``.

    Returns:
        ``F.softmax`` of the TIR output over the last axis. The TIR output
        is declared as ``R.Tensor((x.shape[0], 3), "float32")``.
    """
    # Statistics of the raw input before any conversion.
    print(f" [DEBUG] input shape: {x.shape}, mean: {x.mean():.4f}")

    # Convert both operands (x first, then weights), run the TIR kernel,
    # and convert the result back to PyTorch. The conversion helpers and
    # call_tir are presumably provided by the enclosing module's base class.
    # NOTE(review): the output width is hard-coded to 3; it must agree with
    # the shape matmul_tir actually writes.
    tir_args = [self._convert_pytorch_to_tvm(tensor) for tensor in (x, weights)]
    output_sinfo = R.Tensor((x.shape[0], 3), "float32")
    logits = self._convert_tvm_to_pytorch(
        self.call_tir(self.matmul_tir, tir_args, out_sinfo=output_sinfo)
    )

    # Intermediate logits -- observable here because this runs eagerly in
    # Python rather than only as a compiled artifact.
    lowest, highest = logits.min(), logits.max()
    print(
        f" [DEBUG] logits shape: {logits.shape}, "
        f"min: {lowest:.4f}, max: {highest:.4f}"
    )

    probs = F.softmax(logits, dim=-1)

    # Softmax is taken over the last axis, so each row should sum to ~1.
    print(f" [DEBUG] probs sum: {probs.sum(dim=-1)}")
    return probs
| 175 | |
# Instantiate DebugModule targeting CPU device 0 (tvm.cpu(0)).
mod = DebugModule(device=tvm.cpu(0))
| 177 |
no test coverage detected