Returns a Peft model object from a model and a config. Args: model ([`transformers.PreTrainedModel`]): Model to be wrapped. peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
(model, peft_config)
| 127 | |
| 128 | |
def get_peft_model(model, peft_config):
    """
    Returns a Peft model object from a model and a config.

    The config is completed from the base model's configuration dict before wrapping: a
    [`PromptLearningConfig`] goes through `_prepare_prompt_learning_config`, any other config goes through
    `_prepare_lora_config`. The wrapper class is then looked up in `MODEL_TYPE_TO_PEFT_MODEL_MAPPING` by
    `peft_config.task_type`; when the task type has no entry, the generic [`PeftModel`] is used.

    Note that `peft_config.base_model_name_or_path` is overwritten on the object that was passed in.

    Args:
        model ([`transformers.PreTrainedModel`]): Model to be wrapped.
        peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.

    Returns:
        The task-specific Peft model class registered for `peft_config.task_type` (or [`PeftModel`] when none
        is registered), instantiated with `model` and the prepared config.
    """
    model_config = model.config.to_dict()
    # Looked up in the instance __dict__, so a model without that instance attribute yields None.
    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)

    # Decide on the wrapper before the config is handed to (and possibly replaced by) a prepare helper.
    has_task_specific_class = peft_config.task_type in MODEL_TYPE_TO_PEFT_MODEL_MAPPING

    # Prepare according to the config type on BOTH paths. Previously the fallback path always used the
    # LoRA preparation, even for prompt-learning configs, unlike the task-specific path.
    if isinstance(peft_config, PromptLearningConfig):
        peft_config = _prepare_prompt_learning_config(peft_config, model_config)
    else:
        peft_config = _prepare_lora_config(peft_config, model_config)

    if not has_task_specific_class:
        return PeftModel(model, peft_config)
    return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config)
no test coverage detected