MCPcopy Index your code
hub / github.com/apache/tvm / call_dps_packed

Method call_dps_packed

python/tvm/relax/base_py_module.py:256–275  ·  view source on GitHub ↗

Call a packed function with PyTorch tensors, converting TVM Tensors via DLPack.

(self, func_name: str, args, out_sinfo)

Source from the content-addressed store, hash-verified

254 return result[0] if len(result) == 1 else result
255
def call_dps_packed(self, func_name: str, args, out_sinfo):
    """Call a function in destination-passing style, preferring module methods.

    Resolution order:

    1. If this module has a callable attribute named ``func_name``, it is
       invoked directly as ``attr(*args)`` and its result is returned as-is.
       ``out_sinfo`` is not used on this path.
    2. Otherwise the packed function is taken from ``self.extern_funcs``.
       On a cache miss it is looked up with ``tvm.get_global_func`` and then
       stored in ``self.extern_funcs`` for later calls.

    On the packed-function path:

    * output containers are allocated with ``self._create_output_tensors``;
    * inputs and outputs are both converted with
      ``self._convert_pytorch_to_tvm``;
    * the packed function is called with the converted inputs followed by the
      converted outputs, and its own return value is discarded.

    Parameters
    ----------
    func_name : str
        Name of a method on this module, or of a registered global function.
    args : sequence
        Positional inputs. These are presumably PyTorch tensors (TODO confirm
        against callers).
    out_sinfo
        Output structure info, forwarded to ``_create_output_tensors``.

    Returns
    -------
    object
        On the method path, whatever the method returns. On the packed path,
        the single preallocated output when exactly one was created, otherwise
        the whole collection of outputs.

    Raises
    ------
    ValueError
        If ``func_name`` is neither a callable attribute nor a global function
        that ``tvm.get_global_func`` can find.
    """
    # A user-provided implementation on the module takes priority over any
    # registered global function of the same name.
    local_impl = getattr(self, func_name, None)
    if callable(local_impl):
        return local_impl(*args)

    cache = self.extern_funcs
    if func_name in cache:
        packed = cache[func_name]
    else:
        try:
            packed = tvm.get_global_func(func_name)
        except ValueError as error:
            raise ValueError(
                f"Function '{func_name}' not found as a global function. "
                f"Please implement it as a method or register it."
            ) from error
        # Memoize so later calls skip the global-registry lookup.
        cache[func_name] = packed

    outputs = self._create_output_tensors(out_sinfo, args)
    converted_inputs = self._convert_pytorch_to_tvm(args)
    converted_outputs = self._convert_pytorch_to_tvm(outputs)
    # Destination-passing style: outputs are appended after the inputs, and
    # the preallocated outputs (not the call's return value) are handed back.
    packed(*converted_inputs, *converted_outputs)

    if len(outputs) == 1:
        return outputs[0]
    return outputs
276
277 def call_py_func(self, func_name: str, args):
278 """Call a Python function stored in the module's pyfuncs."""

Calls 4

get_global_func — Method · 0.80
func — Function · 0.50