(
state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
)
| 345 | |
| 346 | |
def _create_lora_config(
    state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
    """Build a ``peft.LoraConfig`` for a LoRA checkpoint.

    The keyword arguments for ``LoraConfig`` are taken verbatim from
    ``metadata`` when it is not ``None``; otherwise they are derived by
    ``get_peft_kwargs`` from the rank pattern, network alphas and state dicts.

    Args:
        state_dict: LoRA weights, forwarded to ``get_peft_kwargs`` as
            ``peft_state_dict`` (only used when ``metadata`` is ``None``).
        network_alphas: Forwarded to ``get_peft_kwargs`` as
            ``network_alpha_dict`` (only used when ``metadata`` is ``None``).
        metadata: Precomputed ``LoraConfig`` kwargs, or ``None`` to derive them.
        rank_pattern_dict: Forwarded positionally to ``get_peft_kwargs``.
        is_unet: Forwarded to ``get_peft_kwargs``.
        model_state_dict: Forwarded to ``get_peft_kwargs``.
        adapter_name: Forwarded to ``get_peft_kwargs``.

    Returns:
        A ``LoraConfig`` constructed from the resolved kwargs.

    Raises:
        ValueError: If ``use_dora`` is truthy and the installed PEFT is older
            than 0.9.0, or ``lora_bias`` is truthy and PEFT is <= 0.13.2.
            ``_maybe_raise_error_for_ambiguous_keys`` may raise as well.
        TypeError: If ``LoraConfig`` rejects the kwargs; the original
            ``TypeError`` is chained as the cause.
    """
    # Function-scope import; presumably keeps peft optional at module import time.
    from peft import LoraConfig

    if metadata is None:
        config_kwargs = get_peft_kwargs(
            rank_pattern_dict,
            network_alpha_dict=network_alphas,
            peft_state_dict=state_dict,
            is_unet=is_unet,
            model_state_dict=model_state_dict,
            adapter_name=adapter_name,
        )
    else:
        # Caller-supplied config kwargs are used as-is, skipping inference.
        config_kwargs = metadata

    _maybe_raise_error_for_ambiguous_keys(config_kwargs)

    # Optional features that need a minimum PEFT version; the version is only
    # queried when the feature is actually requested (short-circuit `and`).
    if config_kwargs.get("use_dora") and is_peft_version("<", "0.9.0"):
        raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.")

    if config_kwargs.get("lora_bias") and is_peft_version("<=", "0.13.2"):
        raise ValueError("lora_bias requires PEFT >= 0.14.0. Please upgrade.")

    try:
        lora_config = LoraConfig(**config_kwargs)
    except TypeError as err:
        raise TypeError("`LoraConfig` class could not be instantiated.") from err
    return lora_config
| 380 | |
| 381 | def _maybe_raise_error_for_ambiguous_keys(config): |
no test coverage detected
searching dependent graphs…