MCPcopy
hub / github.com/ymcui/Chinese-LLaMA-Alpaca-2 / save_pretrained

Method save_pretrained

scripts/training/peft/peft_model.py:85–115  ·  view source on GitHub ↗

r""" This function saves the adapter model and the adapter configuration files to a directory, so that it can be re-loaded using the `LoraModel.from_pretrained` class method, and also used by the `LoraModel.push_to_hub` method. Args: save_directory (`str`)

(self, save_directory, **kwargs)

Source from the content-addressed store, hash-verified

83 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84
def save_pretrained(self, save_directory, **kwargs):
    r"""
    Save the adapter weights and the adapter configuration files to a directory.

    The directory can later be re-loaded with this class's `from_pretrained` class method.

    Args:
        save_directory (`str`):
            Directory where the adapter weights and configuration files will be saved. It is created
            (including parents) if it does not exist.
        **kwargs:
            Additional keyword arguments. Only `state_dict` is read here. When it is given, it is passed
            as the second argument of `get_peft_model_state_dict`; otherwise `None` is passed.
            NOTE(review): presumably a pre-computed state dict to extract the adapter weights from —
            confirm against that helper.

    Raises:
        ValueError: If `save_directory` points to an existing file.

    Side effects:
        If `self.peft_config.base_model_name_or_path` is `None`, it is filled in from the base model and
        that value stays set on the in-memory config after this call returns.
    """
    if os.path.isfile(save_directory):
        raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
    os.makedirs(save_directory, exist_ok=True)

    # Save only the trainable (adapter) weights.
    output_state_dict = get_peft_model_state_dict(self, kwargs.get("state_dict", None))
    torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME))

    # Record the base model's name/path in the config if it is not already set. For prompt-learning
    # configs it is read from `self.base_model` itself; otherwise from `self.base_model.model`.
    if self.peft_config.base_model_name_or_path is None:
        self.peft_config.base_model_name_or_path = (
            self.base_model.__dict__.get("name_or_path", None)
            if isinstance(self.peft_config, PromptLearningConfig)
            else self.base_model.model.__dict__.get("name_or_path", None)
        )

    # The config is written to disk with `inference_mode=True`. The in-memory flag is restored in
    # `finally`, so a failed write does not leave the live config stuck in inference mode.
    inference_mode = self.peft_config.inference_mode
    self.peft_config.inference_mode = True
    try:
        self.peft_config.save_pretrained(save_directory)
    finally:
        self.peft_config.inference_mode = inference_mode
116
117 @classmethod
118 def from_pretrained(cls, model, model_id, **kwargs):

Callers 6

inference_hf.py · File · 0.45
save_model · Method · 0.45
on_train_end · Method · 0.45
save_model · Method · 0.45
on_train_end · Method · 0.45

Calls 1

Tested by

no test coverage detected