hub / github.com/huggingface/diffusers / convert_ldm_unet_checkpoint

Function convert_ldm_unet_checkpoint

scripts/convert_i2vgen_to_diffusers.py:135–445

Takes an LDM-style checkpoint state dict and a config, and returns a checkpoint whose UNet keys are renamed to the diffusers layout.

(checkpoint, config, path=None, extract_ema=False)
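As a rough illustration of what "converted" means here, the sketch below renames a handful of keys using only the mappings that appear in this function's listed source (the `time_embed` entries and the additional-embedding substrings). The helper name `rename_keys`, the two lookup tables, and the toy integer values standing in for tensors are assumptions made for this example and are not part of the diffusers script.

```python
# Illustrative sketch only: `rename_keys`, TIME_EMBED_MAP, SUBSTRING_MAP and
# PASSTHROUGH are hypothetical names; the key strings come from the listed source.
TIME_EMBED_MAP = {
    "time_embed.0.weight": "time_embedding.linear_1.weight",
    "time_embed.0.bias": "time_embedding.linear_1.bias",
    "time_embed.2.weight": "time_embedding.linear_2.weight",
    "time_embed.2.bias": "time_embedding.linear_2.bias",
}

# Substrings that are rewritten when found in a key.
SUBSTRING_MAP = {
    "local_image_concat": "image_latents_proj_in",
    "local_image_embedding": "image_latents_context_embedding",
}

# Substrings whose keys are copied over unchanged.
PASSTHROUGH = ("context_embedding", "fps_embedding")


def rename_keys(unet_state_dict):
    """Copy the time-embedding and additional-embedding entries under new names."""
    new_checkpoint = {}

    # Fixed one-to-one renames for the time embedding.
    for old_key, new_key in TIME_EMBED_MAP.items():
        new_checkpoint[new_key] = unet_state_dict[old_key]

    # Substring-based renames for the additional embeddings.
    watched = (*SUBSTRING_MAP, *PASSTHROUGH)
    for key, value in unet_state_dict.items():
        if any(substring in key for substring in watched):
            renamed = key
            for old, new in SUBSTRING_MAP.items():
                renamed = renamed.replace(old, new)
            new_checkpoint[renamed] = value

    return new_checkpoint


# Toy state dict; integers stand in for tensors.
toy = {
    "time_embed.0.weight": 1,
    "time_embed.0.bias": 2,
    "time_embed.2.weight": 3,
    "time_embed.2.bias": 4,
    "local_image_concat.0.weight": 5,
    "fps_embedding.0.bias": 6,
    "input_blocks.0.0.weight": 7,  # not handled by this sketch
}
converted = rename_keys(toy)
```

In this sketch, keys that match none of the tables (such as `input_blocks.0.0.weight`) are simply left out.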

Source from the content-addressed store, hash-verified

def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."

    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        for key in keys:
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    additional_embedding_substrings = [
        "local_image_concat",
        "context_embedding",
        "local_image_embedding",
        "fps_embedding",
    ]
    for k in unet_state_dict:
        if any(substring in k for substring in additional_embedding_substrings):
            diffusers_key = k.replace("local_image_concat", "image_latents_proj_in").replace(
                "local_image_embedding", "image_latents_context_embedding"
            )
            new_checkpoint[diffusers_key] = unet_state_dict[k]

    # temporal encoder.
    new_checkpoint["image_latents_temporal_encoder.norm1.weight"] = unet_state_dict[
        "local_temporal_encoder.layers.0.0.norm.weight"
    ]
    new_checkpoint["image_latents_temporal_encoder.norm1.bias"] = unet_state_dict[
        "local_temporal_encoder.layers.0.0.norm.bias"
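The EMA branch at the top of the function looks up each UNet weight under a flattened `model_ema.` key, built by dropping the leading `model` component and joining the remaining components without dots. The sketch below reproduces that lookup on a toy dictionary. It is a simplified illustration under a few assumptions: the helper names `ema_key_for` and `extract_unet_weights` and the `ema_threshold` parameter are hypothetical, the input dictionary is read rather than popped, and both branches filter on the `model.diffusion_model.` prefix, which the listed source does only in the EMA branch.

```python
# Illustrative sketch only: helper names and `ema_threshold` are hypothetical.
UNET_PREFIX = "model.diffusion_model."


def ema_key_for(key):
    """Build the flattened EMA key for a `model.`-prefixed parameter name."""
    # Drop the leading "model" component and join the rest without dots.
    return "model_ema." + "".join(key.split(".")[1:])


def extract_unet_weights(checkpoint, extract_ema=False, ema_threshold=100):
    """Return UNet weights with the prefix stripped, preferring EMA if requested."""
    keys = list(checkpoint.keys())
    # The checkpoint counts as "EMA" only if enough keys start with `model_ema`.
    has_ema = sum(k.startswith("model_ema") for k in keys) > ema_threshold

    unet_state_dict = {}
    for key in keys:
        if not key.startswith(UNET_PREFIX):
            continue
        short_key = key[len(UNET_PREFIX):]
        if has_ema and extract_ema:
            unet_state_dict[short_key] = checkpoint[ema_key_for(key)]
        else:
            unet_state_dict[short_key] = checkpoint[key]
    return unet_state_dict


# Toy checkpoint; strings stand in for tensors.
toy_checkpoint = {
    "model.diffusion_model.time_embed.0.weight": "raw",
    "model_ema.diffusion_modeltime_embed0weight": "ema",
}

flat_key = ema_key_for("model.diffusion_model.time_embed.0.weight")
# Threshold lowered to 0 so the single toy EMA key is enough to qualify.
ema_weights = extract_unet_weights(toy_checkpoint, extract_ema=True, ema_threshold=0)
raw_weights = extract_unet_weights(toy_checkpoint, extract_ema=False, ema_threshold=0)
# With the default threshold, the toy checkpoint is not treated as EMA.
default_weights = extract_unet_weights(toy_checkpoint, extract_ema=True)
```

With the default threshold of 100, requesting EMA weights on a checkpoint that has too few `model_ema` keys falls back to the non-EMA weights, which matches the condition in the listed source.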

Callers 1

Calls 7

split · Method · 0.80
renew_attention_paths · Function · 0.70
assign_to_checkpoint · Function · 0.70
renew_resnet_paths · Function · 0.70
renew_temp_conv_paths · Function · 0.70
shave_segments · Function · 0.70
pop · Method · 0.45

Tested by

no test coverage detected
