r""" Returns: A state dict containing just the LoRA parameters.
(unet: UNet2DConditionModel)
| 295 | |
| 296 | |
def unet_lora_state_dict(unet: UNet2DConditionModel) -> dict[str, torch.Tensor]:
    r"""
    Collect the LoRA parameters attached to the sub-modules of `unet`.

    Every module yielded by `unet.named_modules()` that exposes a `set_lora_layer` attribute is inspected. If its
    `lora_layer` attribute is set (not `None`), each entry of that layer's `state_dict()` is copied into the result
    under the key `"{module_name}.lora.{entry_name}"`.

    Args:
        unet: The model whose sub-modules are scanned for LoRA layers.

    Returns:
        A state dict containing just the LoRA parameters. It is empty when no module carries a LoRA layer.
    """
    lora_state_dict = {}

    for name, module in unet.named_modules():
        # Only modules that can host a LoRA layer are of interest.
        if not hasattr(module, "set_lora_layer"):
            continue

        # Default to None so a module that never had `lora_layer` assigned is skipped
        # instead of raising AttributeError and aborting the whole collection.
        lora_layer = getattr(module, "lora_layer", None)
        if lora_layer is None:
            continue

        for lora_layer_matrix_name, lora_param in lora_layer.state_dict().items():
            # Per the original note, the matrix name is expected to refer to the
            # "down" or "up" projection of the LoRA layer.
            lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param

    return lora_state_dict
| 314 | |
| 315 | |
| 316 | def cast_training_params(model: torch.nn.Module | list[torch.nn.Module], dtype=torch.float32): |
nothing calls this directly
no test coverage detected
searching dependent graphs…