Import a module from the modules cache directory and extract a class from it.
(class_name, module_path, force_reload=False)
| 183 | |
| 184 | |
def get_class_in_module(class_name, module_path, force_reload=False):
    """
    Import a module from the modules cache directory and extract a class from it.

    The module file is resolved as ``HF_MODULES_CACHE / module_path``. It is registered in
    ``sys.modules`` under a dotted name derived from ``module_path``: a trailing ``.py`` is
    dropped and path separators become dots. The module code is executed on every call. An
    existing ``sys.modules`` entry is reused unless ``force_reload`` is set.

    Args:
        class_name (`str` or `None`):
            Name of the attribute to return from the module. If `None`, the class is located
            with `find_pipeline_class` instead.
        module_path (`str`):
            Path of the module file, relative to `HF_MODULES_CACHE`.
        force_reload (`bool`, *optional*, defaults to `False`):
            If `True`, drop any existing `sys.modules` entry and invalidate the import caches
            before loading, so that a brand-new module object is created.

    Returns:
        The attribute `class_name` of the loaded module, or the result of
        `find_pipeline_class(module)` when `class_name` is `None`.

    Raises:
        ImportError: If no import spec or loader can be created for the module file.
        AttributeError: If `class_name` is not defined in the module.
    """
    name = os.path.normpath(module_path)
    if name.endswith(".py"):
        name = name[:-3]
    name = name.replace(os.path.sep, ".")
    module_file: Path = Path(HF_MODULES_CACHE) / module_path

    with _HF_REMOTE_CODE_LOCK:
        if force_reload:
            sys.modules.pop(name, None)
            importlib.invalidate_caches()
        cached_module: ModuleType | None = sys.modules.get(name)
        module_spec = importlib.util.spec_from_file_location(name, location=module_file)
        # spec_from_file_location returns None when no loader matches the file. Fail with a
        # clear ImportError instead of an AttributeError on `None` further down.
        if module_spec is None or module_spec.loader is None:
            raise ImportError(
                f"Could not create an import spec for {module_file}",
                name=name,
                path=str(module_file),
            )

        module: ModuleType
        if cached_module is None:
            module = importlib.util.module_from_spec(module_spec)
            # Insert it into sys.modules before any loading begins, as the import system does.
            # Imports of `name` made while the module executes then resolve to this object.
            sys.modules[name] = module
        else:
            module = cached_module

        try:
            # NOTE: the module is executed even when it was already cached. Its module-level
            # state and class objects are therefore recreated on each call.
            module_spec.loader.exec_module(module)
        except BaseException:
            # Follow the import protocol: a module we inserted must not stay in sys.modules in a
            # half-initialized state when its execution fails.
            if cached_module is None and sys.modules.get(name) is module:
                del sys.modules[name]
            raise

    if class_name is None:
        return find_pipeline_class(module)

    return getattr(module, class_name)
| 217 | |
| 218 | def find_pipeline_class(loaded_module): |
no test coverage detected
searching dependent graphs…