Call a packed function with PyTorch tensors, converting TVM Tensors via DLPack.
(self, func_name: str, args, out_sinfo)
| 254 | return result[0] if len(result) == 1 else result |
| 255 | |
def call_dps_packed(self, func_name: str, args, out_sinfo):
    """Call a packed function with PyTorch tensors, converting TVM Tensors via DLPack.

    ``func_name`` is resolved in this order:

    1. A callable attribute of this object with that name. It is invoked as
       ``attr(*args)`` and its result is returned unchanged; no output
       tensors are allocated on this path.
    2. An entry already present in ``self.extern_funcs``.
    3. A TVM global function obtained from ``tvm.get_global_func``. The
       function is stored in ``self.extern_funcs`` so later calls take
       path 2.

    On paths 2 and 3 the outputs are created up front by
    ``self._create_output_tensors(out_sinfo, args)``, inputs and outputs are
    both passed through ``self._convert_pytorch_to_tvm``, and the function is
    called with the converted inputs followed by the converted outputs.

    Args:
        func_name: Name of the method or packed function to call.
        args: Positional inputs. They are unpacked with ``*``, so this must
            be an iterable.
        out_sinfo: Description of the outputs, forwarded as-is to
            ``self._create_output_tensors``.

    Returns:
        On path 1, whatever the method returns. Otherwise the single output
        object when exactly one output was created, else the whole
        collection of outputs.

    Raises:
        ValueError: If ``func_name`` is neither a callable attribute, a
            cached extern function, nor a registered TVM global function.
    """
    # Resolve the attribute once; a missing or non-callable attribute falls
    # through to the packed-function lookup below.
    method = getattr(self, func_name, None)
    if callable(method):
        return method(*args)

    if func_name not in self.extern_funcs:
        try:
            self.extern_funcs[func_name] = tvm.get_global_func(func_name)
        except ValueError as error:
            raise ValueError(
                f"Function '{func_name}' not found as a global function. "
                "Please implement it as a method or register it."
            ) from error
    func = self.extern_funcs[func_name]

    out = self._create_output_tensors(out_sinfo, args)
    # NOTE(review): both conversions are star-unpacked below, so the helper is
    # assumed to return a sequence when given a sequence -- TODO confirm.
    tvm_args = self._convert_pytorch_to_tvm(args)
    tvm_out = self._convert_pytorch_to_tvm(out)
    func(*tvm_args, *tvm_out)
    # The pre-conversion `out` objects are returned, not `tvm_out`; presumably
    # the converted outputs alias the same memory -- verify against the helper.
    return out[0] if len(out) == 1 else out
| 276 | |
| 277 | def call_py_func(self, func_name: str, args): |
| 278 | """Call a Python function stored in the module's pyfuncs.""" |