MCP · Copy · Index your code
hub / github.com/apache/tvm / PipelineModule

Class PipelineModule

docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py:213–246  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

211
@I.ir_module
class PipelineModule(BasePyModule):
    # Tutorial module mixing a TIR kernel (matmul_tir) with a Python-level
    # pipeline (forward) on top of BasePyModule.
    # `@I.ir_module` parses this class as TVMScript, so documentation is kept
    # to `#` comments. Comments are not part of the parsed AST, unlike
    # docstrings or extra annotations.

    @T.prim_func(s_tir=True)
    def matmul_tir(var_A: T.handle, var_B: T.handle, var_C: T.handle):
        # Fixed-shape float32 matrix multiply: C (2x3) = A (2x4) @ B (4x3).
        # The buffer shapes and dtype are pinned by match_buffer, so callers
        # must pass exactly these shapes.
        A = T.match_buffer(var_A, (2, 4), "float32")
        B = T.match_buffer(var_B, (4, 3), "float32")
        C = T.match_buffer(var_C, (2, 3), "float32")
        for i, j, k in T.grid(2, 3, 4):
            with T.sblock("matmul"):
                # "SSR": vi and vj are spatial axes, vk is the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    # Reduction init: zero the output element before
                    # accumulating over vk.
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    @I.pyfunc
    def forward(self, x, weights, bias):
        # Three-stage pipeline:
        #   matmul (TIR) -> bias add (packed function) -> ReLU (Python).
        #
        # x and weights feed matmul_tir, so they must be float32 tensors of
        # shape (2, 4) and (4, 3) respectively.
        # NOTE(review): the required shape/dtype of bias is decided by the
        # external "my_bias_add" function, which is not visible here. Confirm.
        # Returns F.relu applied to the (2, 3) float32 bias-added result.
        # F is presumably torch.nn.functional; verify against the file imports.

        # 1. TIR matmul
        x_tvm = self._convert_pytorch_to_tvm(x)
        w_tvm = self._convert_pytorch_to_tvm(weights)
        h = self.call_tir(
            self.matmul_tir,
            [x_tvm, w_tvm],
            # The output buffer is described explicitly: (2, 3) float32,
            # matching C in matmul_tir.
            out_sinfo=R.Tensor((2, 3), "float32"),
        )
        # NOTE(review): if BasePyModule.call_tir already returns a PyTorch
        # tensor, this conversion (and the manual input conversions above)
        # is redundant but kept for illustration. Confirm.
        h_pt = self._convert_tvm_to_pytorch(h)

        # 2. Packed function for bias add (simulating an external library)
        # Assumes a packed function named "my_bias_add" has been registered
        # before forward runs. It is not defined in this snippet.
        h_biased = self.call_dps_packed(
            "my_bias_add",
            [h_pt, bias],
            out_sinfo=R.Tensor((2, 3), "float32"),
        )

        # 3. Python/PyTorch activation
        return F.relu(h_biased)
247
# Call the decorated PipelineModule with CPU device 0.
# Presumably this yields a runnable BasePyModule instance bound to that
# device. Confirm against BasePyModule's constructor.
mod = PipelineModule(device=tvm.cpu(0))
249

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…