MCPcopy Index your code
hub / github.com/apache/tvm / call_tir

Method call_tir

python/tvm/relax/base_py_module.py:221–254  ·  view source on GitHub ↗

Call a TIR function with PyTorch tensors.

(self, tir_func, args, out_sinfo)

Source from the content-addressed store, hash-verified

219 register_py_func(func_name, wrapped_func)
220
def call_tir(self, tir_func, args, out_sinfo):
    """Call a compiled TIR function with PyTorch tensors.

    Parameters
    ----------
    tir_func : str or object
        Identifies the function to run. It may be the key under which the
        function is stored in ``self.compiled_tir_funcs``, an object whose
        ``name`` or ``__name__`` attribute is such a key, or the stored
        function object itself.
    args
        Input arguments. They are passed through
        ``self._convert_pytorch_to_tvm`` and the converted values are
        unpacked positionally into the call.
    out_sinfo
        Output description, forwarded together with ``args`` to
        ``self._create_output_tensors`` to allocate the outputs.

    Returns
    -------
    object
        The only converted output when exactly one is produced, otherwise
        the whole result of ``self._convert_tvm_to_pytorch``.

    Raises
    ------
    ValueError
        If ``tir_func`` cannot be resolved to an entry of
        ``self.compiled_tir_funcs``.
    """
    registry = self.compiled_tir_funcs

    # Candidate names in priority order: the string itself, then the
    # ``name`` attribute, then ``__name__``.
    if isinstance(tir_func, str):
        candidates = [tir_func]
    else:
        candidates = [getattr(tir_func, attr, None) for attr in ("name", "__name__")]

    func_name = next(
        (name for name in candidates if isinstance(name, str) and name and name in registry),
        None,
    )

    # Fall back to matching the function object itself.  This must also run
    # when a ``name``/``__name__`` attribute exists but is not a registered
    # key (e.g. the callable is registered under a different name than the
    # one it carries), otherwise such callables could never be resolved.
    if func_name is None and not isinstance(tir_func, str):
        for name, func in registry.items():
            if func == tir_func:
                func_name = name
                break

    if not func_name:
        available_funcs = list(registry.keys())
        raise ValueError(
            f"Could not resolve or find compiled TIR function: {tir_func}. "
            f"Available functions: {available_funcs}"
        )
    func = registry[func_name]

    # Outputs are allocated up front and passed after the inputs; the callee
    # is expected to fill them in place (its return value is ignored).
    out = self._create_output_tensors(out_sinfo, args)
    tvm_args = self._convert_pytorch_to_tvm(args)
    tvm_out = self._convert_pytorch_to_tvm(out)

    func(*tvm_args, *tvm_out)

    result = self._convert_tvm_to_pytorch(tvm_out)
    # Unwrap a single output so callers get the value rather than a
    # one-element container.
    return result[0] if len(result) == 1 else result
255
256 def call_dps_packed(self, func_name: str, args, out_sinfo):
257 """Call a packed function with PyTorch tensors, converting TVM Tensors via DLPack."""

Callers 15

visit_call_Method · 0.80
tensor_ir_opFunction · 0.80
visit_call_Method · 0.80
visit_call_Method · 0.80
call_tirFunction · 0.80
__call__Method · 0.80
visit_expr_stmtFunction · 0.80

Calls 6

funcFunction · 0.50
itemsMethod · 0.45
keysMethod · 0.45

Tested by 15

mainMethod · 0.64
mainMethod · 0.64
beforeMethod · 0.64
expectedMethod · 0.64
mainMethod · 0.64
mainMethod · 0.64
mainMethod · 0.64