def _set_state_dict_into_text_encoder(
    lora_state_dict: dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
):
    """
    Sets the `lora_state_dict` into `text_encoder`, a model coming from `transformers`.

    Only the entries whose keys start with `prefix` are used. The prefix is stripped from those keys, the result
    is converted to the diffusers key format and then to the PEFT key format, and it is loaded into
    `text_encoder` under the `"default"` adapter name.

    Args:
        lora_state_dict: The state dictionary to be set.
        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
        text_encoder: The model into which the `lora_state_dict` is set.
    """

    text_encoder_state_dict = {
        f"{k.replace(prefix, '')}": v for k, v in lora_state_dict.items() if k.startswith(prefix)
    }
    text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")

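The key-selection step of `_set_state_dict_into_text_encoder` can be illustrated without torch or PEFT. The following is a minimal, hypothetical sketch: the helper name `select_prefixed_keys`, the example keys, and the `"text_encoder."` / `"text_encoder_2."` prefixes are assumptions for illustration, and string values stand in for tensors. The format conversion and the `set_peft_model_state_dict` call are omitted. The sketch uses `str.removeprefix` (Python 3.9+) instead of the original `str.replace`; the two differ only when `prefix` also occurs later inside a key.

```python
# Hypothetical, dependency-free sketch of the key-selection step only.
# Plain strings stand in for torch tensors; the diffusers/PEFT conversion
# and loading steps are not reproduced here.


def select_prefixed_keys(state_dict: dict[str, object], prefix: str) -> dict[str, object]:
    """Keep entries whose key starts with `prefix`, with that leading prefix removed."""
    # removeprefix strips only a leading match, whereas k.replace(prefix, "")
    # would also delete any later occurrence of `prefix` inside the key.
    return {k.removeprefix(prefix): v for k, v in state_dict.items() if k.startswith(prefix)}


# Hypothetical keys for a pipeline with two text encoders and a UNet.
lora_state_dict = {
    "text_encoder.layers.0.q_proj.lora_A.weight": "A1",
    "text_encoder.layers.0.q_proj.lora_B.weight": "B1",
    "text_encoder_2.layers.0.q_proj.lora_A.weight": "A2",
    "unet.down_blocks.0.lora_A.weight": "U",
}

te1 = select_prefixed_keys(lora_state_dict, "text_encoder.")
te2 = select_prefixed_keys(lora_state_dict, "text_encoder_2.")
print(sorted(te1))
print(sorted(te2))
```

Because the filter uses `startswith`, a prefix that ends in `.` keeps the two encoders' entries apart: `"text_encoder."` does not match keys beginning with `"text_encoder_2."`, and UNet keys are ignored entirely.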
def _collate_lora_metadata(modules_to_save: dict[str, torch.nn.Module]) -> dict[str, Any]: