Takes a state dict and a config, and returns a converted checkpoint.
(checkpoint, config, path=None, extract_ema=False)
| 133 | |
| 134 | |
| 135 | def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): |
| 136 | """ |
| 137 | Takes a state dict and a config, and returns a converted checkpoint. |
| 138 | """ |
| 139 | |
| 140 | # extract state_dict for UNet |
| 141 | unet_state_dict = {} |
| 142 | keys = list(checkpoint.keys()) |
| 143 | |
| 144 | unet_key = "model.diffusion_model." |
| 145 | |
| 146 | # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA |
| 147 | if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: |
| 148 | print(f"Checkpoint {path} has both EMA and non-EMA weights.") |
| 149 | print( |
| 150 | "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" |
| 151 | " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." |
| 152 | ) |
| 153 | for key in keys: |
| 154 | if key.startswith("model.diffusion_model"): |
| 155 | flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) |
| 156 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) |
| 157 | else: |
| 158 | if sum(k.startswith("model_ema") for k in keys) > 100: |
| 159 | print( |
| 160 | "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" |
| 161 | " weights (usually better for inference), please make sure to add the `--extract_ema` flag." |
| 162 | ) |
| 163 | |
| 164 | for key in keys: |
| 165 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) |
| 166 | |
| 167 | new_checkpoint = {} |
| 168 | |
| 169 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] |
| 170 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] |
| 171 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] |
| 172 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] |
| 173 | |
| 174 | additional_embedding_substrings = [ |
| 175 | "local_image_concat", |
| 176 | "context_embedding", |
| 177 | "local_image_embedding", |
| 178 | "fps_embedding", |
| 179 | ] |
| 180 | for k in unet_state_dict: |
| 181 | if any(substring in k for substring in additional_embedding_substrings): |
| 182 | diffusers_key = k.replace("local_image_concat", "image_latents_proj_in").replace( |
| 183 | "local_image_embedding", "image_latents_context_embedding" |
| 184 | ) |
| 185 | new_checkpoint[diffusers_key] = unet_state_dict[k] |
| 186 | |
| 187 | # temporal encoder. |
| 188 | new_checkpoint["image_latents_temporal_encoder.norm1.weight"] = unet_state_dict[ |
| 189 | "local_temporal_encoder.layers.0.0.norm.weight" |
| 190 | ] |
| 191 | new_checkpoint["image_latents_temporal_encoder.norm1.bias"] = unet_state_dict[ |
| 192 | "local_temporal_encoder.layers.0.0.norm.bias" |
no test coverage detected
searching dependent graphs…