(self, file_path="", state_dict={}, lora_alpha=1.0)
| 371 | |
| 372 | |
def load_lora(self, file_path="", state_dict=None, lora_alpha=1.0):
    """Load one or more LoRA checkpoints into every matching managed model.

    Args:
        file_path: Path of a LoRA checkpoint, or a list/tuple of paths. A
            sequence is handled by calling ``load_lora`` once per path.
        state_dict: Optional pre-loaded LoRA weights. When ``None`` or empty,
            the weights are read from ``file_path`` via ``load_state_dict``.
            NOTE(review): a non-empty dict combined with a sequence of paths
            is reused for every path; confirm callers intend that.
        lora_alpha: Scaling factor forwarded to the loader as ``alpha``.

    For each ``(model_name, model, model_path)`` triple held by ``self``, the
    first loader from ``get_lora_loaders()`` whose ``match`` returns a
    non-``None`` result is used to apply the LoRA to that model. If no model
    accepted the LoRA, a "Cannot load LoRA" message is printed.
    """
    if isinstance(file_path, (list, tuple)):
        for single_path in file_path:
            self.load_lora(single_path, state_dict=state_dict, lora_alpha=lora_alpha)
        return

    print(f"Loading LoRA models from file: {file_path}")
    # `None` replaces the former shared mutable `{}` default; an explicitly
    # passed empty dict is still treated as "nothing pre-loaded".
    if not state_dict:
        state_dict = load_state_dict(file_path)

    is_loaded = False
    for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
        for lora in get_lora_loaders():
            match_results = lora.match(model, state_dict)
            if match_results is None:
                continue
            print(f" Adding LoRA to {model_name} ({model_path}).")
            lora_prefix, model_resource = match_results
            lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
            is_loaded = True
            # Only the first matching loader is applied to this model; the
            # outer loop still continues so other models can also receive it.
            break
    if not is_loaded:
        print(f" Cannot load LoRA: {file_path}")
| 393 | |
| 394 | |
| 395 | def load_model(self, file_path, model_names=None, device=None, torch_dtype=None): |
no test coverage detected