(
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
text_encoder_name="text_encoder",
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
)
| 319 | |
| 320 | |
| 321 | def _load_lora_into_text_encoder( |
| 322 | state_dict, |
| 323 | network_alphas, |
| 324 | text_encoder, |
| 325 | prefix=None, |
| 326 | lora_scale=1.0, |
| 327 | text_encoder_name="text_encoder", |
| 328 | adapter_name=None, |
| 329 | _pipeline=None, |
| 330 | low_cpu_mem_usage=False, |
| 331 | hotswap: bool = False, |
| 332 | metadata=None, |
| 333 | ): |
| 334 | from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading |
| 335 | |
| 336 | if not USE_PEFT_BACKEND: |
| 337 | raise ValueError("PEFT backend is required for this method.") |
| 338 | |
| 339 | if network_alphas and metadata: |
| 340 | raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.") |
| 341 | |
| 342 | peft_kwargs = {} |
| 343 | if low_cpu_mem_usage: |
| 344 | if not is_peft_version(">=", "0.13.1"): |
| 345 | raise ValueError( |
| 346 | "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." |
| 347 | ) |
| 348 | if not is_transformers_version(">", "4.45.2"): |
| 349 | # Note from sayakpaul: It's not in `transformers` stable yet. |
| 350 | # https://github.com/huggingface/transformers/pull/33725/ |
| 351 | raise ValueError( |
| 352 | "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." |
| 353 | ) |
| 354 | peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage |
| 355 | |
| 356 | # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), |
| 357 | # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as |
| 358 | # their prefixes. |
| 359 | prefix = text_encoder_name if prefix is None else prefix |
| 360 | |
| 361 | # Safe prefix to check with. |
| 362 | if hotswap and any(text_encoder_name in key for key in state_dict.keys()): |
| 363 | raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.") |
| 364 | |
| 365 | # Load the layers corresponding to text encoder and make necessary adjustments. |
| 366 | if prefix is not None: |
| 367 | state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} |
| 368 | if metadata is not None: |
| 369 | metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")} |
| 370 | |
| 371 | if len(state_dict) > 0: |
| 372 | logger.info(f"Loading {prefix}.") |
| 373 | rank = {} |
| 374 | state_dict = convert_state_dict_to_diffusers(state_dict) |
| 375 | |
| 376 | # convert state dict |
| 377 | state_dict = convert_state_dict_to_peft(state_dict) |
| 378 |
no test coverage detected
searching dependent graphs…