(self, file_path, model_names=None, device=None, torch_dtype=None)
| 393 | |
| 394 | |
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
    """Detect the model type stored at ``file_path`` and register the loaded models.

    ``file_path`` is either a single path or a list of paths.

    * For a list, the state dicts of all paths are merged into one dict, with
      later paths overwriting earlier keys.
    * For a single path that is an existing file, its state dict is loaded.
    * For any other path, no state dict is read and ``None`` is handed to the
      detectors. Presumably this covers directory-style checkpoints; confirm
      against the detector implementations.

    The detectors in ``self.model_detector`` are tried in order. The first one
    whose ``match`` accepts ``(file_path, state_dict)`` performs the load.
    Every returned model is appended to ``self.model``, its name to
    ``self.model_name``, and the original ``file_path`` to ``self.model_path``.
    If no detector matches, nothing is registered and a message is printed.

    Args:
        file_path: A path, or a list of paths whose state dicts are merged.
        model_names: Forwarded to the detector as ``allowed_model_names``.
        device: Target device; falls back to ``self.device`` when ``None``.
        torch_dtype: Target dtype; falls back to ``self.torch_dtype`` when ``None``.

    Returns:
        None. Results are recorded on ``self``.
    """
    print(f"Loading models from: {file_path}")
    device = self.device if device is None else device
    torch_dtype = self.torch_dtype if torch_dtype is None else torch_dtype

    if isinstance(file_path, list):
        # Merge all shards into a single dict; a later path wins on key clashes.
        state_dict = {}
        for shard_path in file_path:
            state_dict.update(load_state_dict(shard_path))
    elif os.path.isfile(file_path):
        state_dict = load_state_dict(file_path)
    else:
        # Not a single file: the detectors must decide from the path alone.
        state_dict = None

    # Take the first detector that claims this input. `match` is called lazily,
    # in list order, and evaluation stops at the first truthy result.
    detector = next(
        (
            candidate
            for candidate in self.model_detector
            if candidate.match(file_path, state_dict)
        ),
        None,
    )
    if detector is None:
        print(" We cannot detect the model type. No models are loaded.")
        return

    loaded_names, loaded_models = detector.load(
        file_path,
        state_dict,
        device=device,
        torch_dtype=torch_dtype,
        allowed_model_names=model_names,
        model_manager=self,
    )
    for name, model in zip(loaded_names, loaded_models):
        self.model.append(model)
        self.model_path.append(file_path)
        self.model_name.append(name)
    print(f" The following models are loaded: {loaded_names}.")
| 422 | |
| 423 | |
| 424 | def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None): |
no test coverage detected