Call a TIR function with PyTorch tensors.
(self, tir_func, args, out_sinfo)
| 219 | register_py_func(func_name, wrapped_func) |
| 220 | |
def call_tir(self, tir_func, args, out_sinfo):
    """Call a compiled TIR function with PyTorch tensors.

    The function is looked up in ``self.compiled_tir_funcs``. Output tensors
    are created from ``out_sinfo``, inputs and outputs are converted to TVM
    tensors, and the function is called with the inputs followed by the
    outputs. The outputs are then converted back and returned.

    Parameters
    ----------
    tir_func : str or object
        The function to call. This is either its registered name, or an object
        that can be resolved to one. An object resolves in one of two ways:
        through a ``name``, ``name_hint`` or ``__name__`` attribute holding a
        registered name (tried in that order), or by being / comparing equal
        to one of the compiled functions itself.
    args : sequence
        Input tensors. They are converted with ``self._convert_pytorch_to_tvm``.
    out_sinfo
        Output struct info. It is passed together with ``args`` to
        ``self._create_output_tensors`` to build the output tensors.

    Returns
    -------
    The single converted output when there is exactly one. Otherwise, the
    whole sequence returned by ``self._convert_tvm_to_pytorch``.

    Raises
    ------
    ValueError
        If ``tir_func`` cannot be resolved to a compiled TIR function.
    """
    compiled = self.compiled_tir_funcs
    func_name = None

    if isinstance(tir_func, str):
        # A string is taken as the exact registered name; no fallback search.
        func_name = tir_func
    else:
        # Try each naming attribute and keep the first one that is actually
        # registered, rather than committing to the first attribute that
        # merely exists. `name_hint` is meant for objects such as a TVM
        # GlobalVar (assumption - confirm against callers).
        for attr in ("name", "name_hint", "__name__"):
            candidate = getattr(tir_func, attr, None)
            # Only string names are considered, so unhashable or unrelated
            # attribute values cannot break the dictionary lookup.
            if isinstance(candidate, str) and candidate in compiled:
                func_name = candidate
                break
        else:
            # No attribute named a registered function: match the object itself.
            # `==` is kept (not just `is`) so wrapper types with value/handle
            # equality still match.
            for name, func in compiled.items():
                if func is tir_func or func == tir_func:
                    func_name = name
                    break

    if not func_name or func_name not in compiled:
        available_funcs = list(compiled.keys())
        raise ValueError(
            f"Could not resolve or find compiled TIR function: {tir_func}. "
            f"Available functions: {available_funcs}"
        )
    func = compiled[func_name]

    out = self._create_output_tensors(out_sinfo, args)
    tvm_args = self._convert_pytorch_to_tvm(args)
    tvm_out = self._convert_pytorch_to_tvm(out)

    # Outputs are passed after the inputs; the function fills them in and
    # they are read back below.
    func(*tvm_args, *tvm_out)

    result = self._convert_tvm_to_pytorch(tvm_out)
    return result[0] if len(result) == 1 else result
| 255 | |
| 256 | def call_dps_packed(self, func_name: str, args, out_sinfo): |
| 257 | """Call a packed function with PyTorch tensors, converting TVM Tensors via DLPack.""" |