Repository: github.com/huggingface/diffusers

Function unet_lora_state_dict

Defined in src/diffusers/training_utils.py, lines 297–313.

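Signature: unet_lora_state_dict(unet: UNet2DConditionModel) -> dict[str, torch.Tensor]

Returns a state dict containing just the LoRA parameters of unet. Each key has the form {module_name}.lora.{matrix_name}, where module_name is the path of a LoRA-compatible module inside the UNet and matrix_name is a key from the attached LoRA layer's own state dict (its "down" and "up" projection weights).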
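To see why a LoRA-only state dict is much smaller than the weights it adapts, consider one adapted linear layer. This is a minimal arithmetic sketch, assuming each adapter is a bias-free "down" projection (rank x in_features) followed by an "up" projection (out_features x rank), which is the usual LoRA construction; the layer class itself is not shown on this page, and the 320 -> 320 size and rank 4 are illustrative values only.

```python
def lora_param_count(in_features: int, out_features: int, rank: int) -> int:
    # Assumed adapter shapes: down is rank x in_features, up is out_features x rank.
    return rank * in_features + out_features * rank


def dense_param_count(in_features: int, out_features: int) -> int:
    # Weight matrix of the frozen base linear layer (bias ignored).
    return in_features * out_features


# Illustrative sizes only: a square 320 -> 320 projection with a rank-4 adapter.
lora = lora_param_count(320, 320, rank=4)
dense = dense_param_count(320, 320)
print(lora, dense, lora / dense)  # 2560 102400 0.025
```

Under these assumptions the adapter stores r * (in_features + out_features) values instead of in_features * out_features, so saving only the LoRA entries keeps checkpoints small.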


```python
import torch
from diffusers import UNet2DConditionModel


def unet_lora_state_dict(unet: UNet2DConditionModel) -> dict[str, torch.Tensor]:
    r"""
    Returns:
        A state dict containing just the LoRA parameters.
    """
    lora_state_dict = {}

    for name, module in unet.named_modules():
        # Duck typing: LoRA-compatible layers expose `set_lora_layer` and a `lora_layer` attribute.
        if hasattr(module, "set_lora_layer"):
            lora_layer = getattr(module, "lora_layer")
            # Skip LoRA-compatible layers that have no adapter attached.
            if lora_layer is not None:
                current_lora_layer_sd = lora_layer.state_dict()
                for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
                    # The matrix name starts with either "down" or "up" (e.g. "down.weight").
                    lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param

    return lora_state_dict
```
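The traversal and key scheme can be exercised without a real UNet2DConditionModel. The sketch below runs the same loop over a toy model; ToyLoRALayer, ToyLoRACompatibleLinear and lora_state_dict are hypothetical stand-ins written for this example (not diffusers classes) that only mimic the set_lora_layer / lora_layer attributes the function checks for. It requires PyTorch.

```python
import torch
from torch import nn


class ToyLoRALayer(nn.Module):
    """Hypothetical LoRA layer: a rank-r down projection followed by an up projection."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)


class ToyLoRACompatibleLinear(nn.Linear):
    """Hypothetical linear layer exposing the attributes the extraction loop looks for."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features)
        self.lora_layer = None  # no adapter attached yet

    def set_lora_layer(self, lora_layer: nn.Module) -> None:
        self.lora_layer = lora_layer  # registered as a submodule by nn.Module


def lora_state_dict(model: nn.Module) -> dict[str, torch.Tensor]:
    """Same loop as unet_lora_state_dict, accepting any nn.Module."""
    result = {}
    for name, module in model.named_modules():
        if hasattr(module, "set_lora_layer"):
            lora_layer = getattr(module, "lora_layer")
            if lora_layer is not None:
                for matrix_name, param in lora_layer.state_dict().items():
                    result[f"{name}.lora.{matrix_name}"] = param
    return result


# Toy "model" with two LoRA-compatible projections; only to_q gets an adapter.
model = nn.ModuleDict({
    "to_q": ToyLoRACompatibleLinear(8, 8),
    "to_k": ToyLoRACompatibleLinear(8, 8),
})
model["to_q"].set_lora_layer(ToyLoRALayer(8, 8, rank=2))

sd = lora_state_dict(model)
for key, tensor in sd.items():
    print(key, tuple(tensor.shape))
# to_q.lora.down.weight (2, 8)
# to_q.lora.up.weight (8, 2)
```

Layers whose lora_layer is None (to_k here) contribute no keys. The values come from state_dict(), which returns detached tensors that share storage with the live parameters, so call .clone() on them if you need a snapshot that will not change as training continues.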

Callers

nothing calls this directly

Calls (1)

state_dict (method) · 0.45

Tested by

no test coverage detected
