r""" This function saves the adapter model and the adapter configuration files to a directory, so that it can be re-loaded using the `LoraModel.from_pretrained` class method, and also used by the `LoraModel.push_to_hub` method. Args: save_directory (`str`)
(self, save_directory, **kwargs)
| 83 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 84 | |
def save_pretrained(self, save_directory, **kwargs):
    r"""
    This function saves the adapter model and the adapter configuration files to a directory, so that it can be
    re-loaded using the `LoraModel.from_pretrained` class method, and also used by the `LoraModel.push_to_hub`
    method.

    Args:
        save_directory (`str`):
            Directory where the adapter model and configuration files will be saved (will be created if it does not
            exist).
        **kwargs:
            Additional keyword arguments. Only `state_dict` is read here: it is forwarded as the second argument of
            `get_peft_model_state_dict` (`None` when not supplied). Any other keys are ignored by this method.

    Raises:
        ValueError: If `save_directory` is an existing file rather than a directory.
    """
    if os.path.isfile(save_directory):
        raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
    os.makedirs(save_directory, exist_ok=True)

    # save only the trainable weights
    output_state_dict = get_peft_model_state_dict(self, kwargs.get("state_dict", None))
    torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME))

    # Fill in the base model's `name_or_path` if the config does not already record it. Prompt-learning configs
    # read it from `self.base_model` directly; other configs read it from `self.base_model.model`
    # (presumably because those tuners wrap the underlying model one level deeper — confirm).
    if self.peft_config.base_model_name_or_path is None:
        if isinstance(self.peft_config, PromptLearningConfig):
            source_model = self.base_model
        else:
            source_model = self.base_model.model
        self.peft_config.base_model_name_or_path = source_model.__dict__.get("name_or_path", None)

    # Save the config with `inference_mode=True`. The caller's original value is restored in `finally`,
    # so a failure while writing the config does not leave the live model stuck in inference mode.
    inference_mode = self.peft_config.inference_mode
    self.peft_config.inference_mode = True
    try:
        self.peft_config.save_pretrained(save_directory)
    finally:
        self.peft_config.inference_mode = inference_mode
| 116 | |
| 117 | @classmethod |
| 118 | def from_pretrained(cls, model, model_id, **kwargs): |
no test coverage detected