(vae_state_dict)
| 179 | |
| 180 | |
def convert_vae_state_dict(vae_state_dict):
    """Rename (and partially reshape) a VAE state dict into SD-format keys.

    Every key is rewritten through ``vae_conversion_map`` (replacing each
    ``hf_part`` substring with its ``sd_part``). Keys whose original name
    contains ``"attentions"`` are additionally rewritten through
    ``vae_conversion_map_attn``. Afterwards:

    * ``mid.attn_1.{q,k,v,proj_out}.weight`` tensors are passed through
      ``reshape_weight_for_sd``;
    * ``mid.attn_1`` keys still using a name from ``vae_extra_conversion_map``
      are renamed to the mapped name. Renamed weights are passed through
      ``reshape_weight_for_sd``; renamed biases are kept unchanged.

    Args:
        vae_state_dict: Mapping from source (``hf_part``-style) parameter names
            to tensors. It is not modified.

    Returns:
        A new dict keyed by the converted names.
    """
    # Build old-name -> new-name mapping. The attention-specific renames are
    # applied on top of the generic ones, but selected by the ORIGINAL key.
    mapping = {}
    for hf_key in vae_state_dict:
        sd_key = hf_key
        for sd_part, hf_part in vae_conversion_map:
            sd_key = sd_key.replace(hf_part, sd_part)
        if "attentions" in hf_key:
            for sd_part, hf_part in vae_conversion_map_attn:
                sd_key = sd_key.replace(hf_part, sd_part)
        mapping[hf_key] = sd_key
    new_state_dict = {sd_key: vae_state_dict[hf_key] for hf_key, sd_key in mapping.items()}

    weights_to_convert = ["q", "k", "v", "proj_out"]
    # old key -> (new key, whether the key matched as a ".weight" entry)
    keys_to_rename = {}
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                # Overwriting an existing key is safe while iterating items().
                new_state_dict[k] = reshape_weight_for_sd(v)
        for weight_name, real_weight_name in vae_extra_conversion_map:
            is_weight = f"mid.attn_1.{weight_name}.weight" in k
            is_bias = f"mid.attn_1.{weight_name}.bias" in k
            if is_weight or is_bias:
                keys_to_rename[k] = (k.replace(weight_name, real_weight_name), is_weight)

    # Renaming is deferred so the dict is not resized while being iterated.
    for old_key, (new_key, is_weight) in keys_to_rename.items():
        if old_key in new_state_dict:
            print(f"Renaming {old_key} to {new_key}")
            tensor = new_state_dict.pop(old_key)
            # Consistent with the loop above, which reshapes only ".weight"
            # keys: biases must keep their original shape.
            if is_weight:
                tensor = reshape_weight_for_sd(tensor)
            new_state_dict[new_key] = tensor
    return new_state_dict
| 209 | |
| 210 | |
| 211 | # =========================# |
no test coverage detected