| 314 | |
| 315 | |
| 316 | class ModelManager: |
def __init__(
    self,
    torch_dtype=torch.float16,
    device="cuda",
    model_id_list: "List[Preset_model_id] | None" = None,
    downloading_priority: "List[Preset_model_website] | None" = None,
    file_path_list: "List[str] | None" = None,
):
    """Create a model manager and immediately load the requested models.

    Args:
        torch_dtype: dtype stored on the manager and forwarded to the loaders.
        device: device string stored on the manager and forwarded to the loaders.
        model_id_list: preset model ids to download before loading. ``None``
            (the default) or an empty list downloads nothing.
        downloading_priority: order of websites passed to ``download_models``.
            Defaults to ``["ModelScope", "HuggingFace"]``.
        file_path_list: additional paths to load after the downloaded ones.
            Defaults to no extra paths.
    """
    # None sentinels replace the former mutable list defaults: a default list is
    # created once at definition time and would otherwise be shared (and
    # potentially mutated by download_models) across every ModelManager.
    if model_id_list is None:
        model_id_list = []
    if downloading_priority is None:
        downloading_priority = ["ModelScope", "HuggingFace"]
    if file_path_list is None:
        file_path_list = []

    self.torch_dtype = torch_dtype
    self.device = device
    # Parallel lists: the load_* methods append one entry to each per loaded
    # model (the model object, its source path, and its name).
    self.model = []
    self.model_path = []
    self.model_name = []

    downloaded_files = download_models(model_id_list, downloading_priority) if model_id_list else []

    # Detectors for the supported on-disk layouts; presumably tried in order by
    # load_models to decide how each path is loaded — confirm.
    self.model_detector = [
        ModelDetectorFromSingleFile(model_loader_configs),
        ModelDetectorFromSplitedSingleFile(model_loader_configs),
        ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
        ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
    ]
    # Downloaded files first, then the caller's explicit paths. Unpacking
    # accepts any iterable (e.g. a tuple) for file_path_list.
    self.load_models([*downloaded_files, *file_path_list])
| 338 | |
| 339 | |
def load_model_from_single_file(self, file_path="", state_dict=None, model_names=None, model_classes=None, model_resource=None):
    """Load the models contained in a single checkpoint and register them.

    Args:
        file_path: checkpoint path. It is read only when ``state_dict`` is
            ``None`` or empty, and it is always recorded as each model's path.
        state_dict: already-loaded weights. When non-empty, ``file_path`` is
            not read.
        model_names: names forwarded to the loader (default: empty list).
        model_classes: classes forwarded to the loader (default: empty list).
        model_resource: forwarded unchanged to the loader.
    """
    print(f"Loading models from file: {file_path}")
    # `not state_dict` covers both the None default and an explicitly passed
    # empty dict, matching the original `len(state_dict) == 0` behavior.
    if not state_dict:
        state_dict = load_state_dict(file_path)
    # The bare name below resolves to the module-level (global) function of the
    # same name, not to this method: class scope is skipped in name lookup.
    # Fresh lists are passed per call instead of shared mutable defaults.
    model_names, models = load_model_from_single_file(
        state_dict,
        [] if model_names is None else model_names,
        [] if model_classes is None else model_classes,
        model_resource,
        self.torch_dtype,
        self.device,
    )
    for model_name, model in zip(model_names, models):
        self.model.append(model)
        self.model_path.append(file_path)
        self.model_name.append(model_name)
    print(f"    The following models are loaded: {model_names}.")
| 350 | |
| 351 | |
def load_model_from_huggingface_folder(self, file_path="", model_names=None, model_classes=None):
    """Load the models from a HuggingFace-style folder and register them.

    Args:
        file_path: folder path forwarded to the loader and recorded as each
            model's path.
        model_names: names forwarded to the loader (default: empty list).
        model_classes: classes forwarded to the loader (default: empty list).
    """
    print(f"Loading models from folder: {file_path}")
    # The bare name resolves to the module-level (global) loader, not this
    # method. Fresh lists replace the former shared mutable defaults.
    model_names, models = load_model_from_huggingface_folder(
        file_path,
        [] if model_names is None else model_names,
        [] if model_classes is None else model_classes,
        self.torch_dtype,
        self.device,
    )
    for model_name, model in zip(model_names, models):
        self.model.append(model)
        self.model_path.append(file_path)
        self.model_name.append(model_name)
    print(f"    The following models are loaded: {model_names}.")
| 360 | |
| 361 | |
def load_patch_model_from_single_file(self, file_path="", state_dict=None, model_names=None, model_classes=None, extra_kwargs=None):
    """Load patch models from a single checkpoint and register them.

    Args:
        file_path: checkpoint path, recorded as each model's path. When no
            (or an empty) ``state_dict`` is given, the weights are read from
            this path, consistent with ``load_model_from_single_file``.
        state_dict: already-loaded weights.
        model_names: names forwarded to the loader (default: empty list).
        model_classes: classes forwarded to the loader (default: empty list).
        extra_kwargs: extra options forwarded to the loader (default: empty dict).
    """
    print(f"Loading patch models from file: {file_path}")
    # Previously an empty state_dict was forwarded as-is even when a path was
    # given. Load it from disk instead. Reading is skipped when file_path is
    # empty, so calls that pass neither argument behave as before.
    if not state_dict and file_path:
        state_dict = load_state_dict(file_path)
    # The bare name resolves to the module-level (global) loader, not this
    # method. The manager itself (`self`) is passed along; presumably the
    # loader uses it to access already-loaded models — confirm.
    model_names, models = load_patch_model_from_single_file(
        {} if state_dict is None else state_dict,
        [] if model_names is None else model_names,
        [] if model_classes is None else model_classes,
        {} if extra_kwargs is None else extra_kwargs,
        self,
        self.torch_dtype,
        self.device,
    )
    for model_name, model in zip(model_names, models):
        self.model.append(model)
        self.model_path.append(file_path)
        self.model_name.append(model_name)
    print(f"    The following patched models are loaded: {model_names}.")
| 371 | |
| 372 | |
| 373 | def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0): |
no outgoing calls