MCPcopy Index your code
hub / github.com/apache/tvm / DebugModule

Class DebugModule

docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py:135–174  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

133
# Debugging-oriented module: pairs a TIR matmul kernel with a Python-level
# ``forward`` that prints tensor statistics before and after calling it.
@I.ir_module
class DebugModule(BasePyModule):
    # TIR kernel computing C = A @ B with A: (n, 4), B: (4, 3), C: (n, 3),
    # all float32. ``n`` is a symbolic int32 shared by the A and C buffers.
    @T.prim_func(s_tir=True)
    def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle):
        n = T.int32()
        A = T.match_buffer(var_A, (n, 4), "float32")
        B = T.match_buffer(var_B, (4, 3), "float32")
        C = T.match_buffer(var_C, (n, 3), "float32")
        for i, j, k in T.grid(n, 3, 4):
            with T.sblock("matmul"):
                # "SSR": i and j are bound as spatial axes, k as the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Zero the accumulator before the reduction over vk starts.
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    # Python-side entry point: runs ``matmul_tir`` on (x, weights), applies
    # ``F.softmax`` over the last dimension, and prints debug statistics at
    # each stage. Returns the softmax result.
    # NOTE(review): x and weights are presumably PyTorch tensors (they go
    # through ``_convert_pytorch_to_tvm``), shaped (n, 4) and (4, 3) to match
    # the kernel's buffers -- confirm against callers.
    @I.pyfunc
    def forward(self, x, weights):
        # Inspect input
        print(f" [DEBUG] input shape: {x.shape}, mean: {x.mean():.4f}")

        # Run TIR matmul
        x_tvm = self._convert_pytorch_to_tvm(x)
        w_tvm = self._convert_pytorch_to_tvm(weights)
        out = self.call_tir(
            self.matmul_tir,
            [x_tvm, w_tvm],
            # Output row count follows the input's leading dimension.
            out_sinfo=R.Tensor((x.shape[0], 3), "float32"),
        )
        logits = self._convert_tvm_to_pytorch(out)

        # Inspect intermediate value — impossible with a compiled-only workflow
        print(
            f" [DEBUG] logits shape: {logits.shape}, "
            f"min: {logits.min():.4f}, max: {logits.max():.4f}"
        )

        result = F.softmax(logits, dim=-1)

        # Verify output
        print(f" [DEBUG] probs sum: {result.sum(dim=-1)}")
        return result
175
# Instantiate the debug module on CPU device 0.
mod = DebugModule(device=tvm.cpu(0))
177

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…