Base class that allows Python functions in IRModule with DLPack conversion. This class provides the infrastructure for: 1. JIT compilation of TIR and Relax functions. 2. DLPack-based conversion between PyTorch tensors and TVM Tensors. 3. Wrapping Relax functions for easy Python calling. 4. Cross-function calls between Python, TIR, and Relax functions.
| 44 | |
| 45 | |
| 46 | class BasePyModule: |
| 47 | """Base class that allows Python functions in IRModule with DLPack conversion. |
| 48 | |
| 49 | This class provides the infrastructure for: |
| 50 | 1. JIT compilation of TIR and Relax functions. |
| 51 | 2. DLPack-based conversion between PyTorch tensors and TVM Tensors. |
| 52 | 3. Wrapping Relax functions for easy Python calling. |
| 53 | 4. Cross-function calls between Python, TIR, and Relax functions. |
| 54 | |
| 55 | Only IRModules that inherit from this class are allowed to contain Python functions. |
| 56 | """ |
| 57 | |
def __del__(self):
    """Best-effort finalizer that clears the Python-function registry.

    Fetches the global function ``vm.builtin.clear_py_func_registry`` through
    ``tvm.get_global_func`` and invokes it with no arguments. Any
    ``AttributeError`` or ``ValueError`` raised by the lookup or the call is
    swallowed, so garbage collection never propagates those two errors.
    """
    # NOTE(review): the clear call receives no reference to this instance, so
    # whatever it clears is cleared regardless of which module is being
    # collected -- confirm this cannot drop functions still needed by other
    # live BasePyModule instances.
    try:
        tvm.get_global_func("vm.builtin.clear_py_func_registry")()
    except (AttributeError, ValueError):
        # Presumably: ValueError when the global function is not registered,
        # AttributeError if ``tvm`` is already torn down at interpreter
        # shutdown -- either way there is nothing to clean up (TODO confirm).
        return
| 65 | |
| 66 | def __init__( |
| 67 | self, |
| 68 | ir_mod: IRModule, |
| 69 | device: Device, |
| 70 | target: Target | None = None, |
| 71 | ): |
| 72 | """Initialize BasePyModule with JIT compilation and DLPack conversion.""" |
| 73 | self.device = device |
| 74 | self.ir_mod = ir_mod |
| 75 | |
| 76 | # Delegate IRModule operations |
| 77 | self.functions = ir_mod.functions |
| 78 | self.attrs = ir_mod.attrs |
| 79 | self.global_infos = ir_mod.global_infos |
| 80 | self.__getitem__ = ir_mod.__getitem__ |
| 81 | self.__setitem__ = ir_mod.__setitem__ |
| 82 | self.functions_items = ir_mod.functions_items |
| 83 | self.with_attr = ir_mod.with_attr |
| 84 | self.get_attr = ir_mod.get_attr |
| 85 | self.update_global_info = ir_mod.update_global_info |
| 86 | |
| 87 | def _getattr_python_function(name: str) -> Any: |
| 88 | """Support direct attribute access to funcs and IRModule methods.""" |
| 89 | if name in self.pyfuncs: |
| 90 | return self.pyfuncs[name] |
| 91 | if name in self.compiled_tir_funcs: |
| 92 | return self.compiled_tir_funcs[name] |
| 93 | if self.relax_vm and name in self.relax_func_names: |
| 94 | try: |
| 95 | return self.relax_vm[name] |
| 96 | except AttributeError: # More specific exception |
| 97 | return None |
| 98 | if hasattr(self.ir_mod, name): |
| 99 | return getattr(self.ir_mod, name) |
| 100 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") |
| 101 | |
| 102 | self.__getattr__ = _getattr_python_function |
| 103 |
no outgoing calls
searching dependent graphs…