MCPcopy Index your code
hub / github.com/XPixelGroup/DiffBIR / convert_vae_state_dict

Function convert_vae_state_dict

scripts/convert_diffusers_to_sd.py:181–208  ·  view source on GitHub ↗
(vae_state_dict)

Source from the content-addressed store, hash-verified

179
180
def convert_vae_state_dict(vae_state_dict):
    """Build a new state dict whose keys are renamed for the SD layout.

    Each key of ``vae_state_dict`` is rewritten by applying, in order, every
    ``(sd_part, hf_part)`` substring substitution from ``vae_conversion_map``.
    Keys whose *original* name contains ``"attentions"`` additionally receive
    the substitutions from ``vae_conversion_map_attn``.

    Afterwards, entries whose new name contains ``mid.attn_1.<name>.weight``
    for one of ``q``/``k``/``v``/``proj_out`` are passed through
    ``reshape_weight_for_sd``, and entries matching a name from
    ``vae_extra_conversion_map`` (weight or bias) are stored under the
    replacement name, also passed through ``reshape_weight_for_sd``.

    Args:
        vae_state_dict: Mapping from parameter name to value. Only read here;
            results are collected in a separate dict.

    Returns:
        A new dict keyed by the converted parameter names.
    """
    converted = {}
    for hf_key in vae_state_dict.keys():
        sd_key = hf_key
        for sd_part, hf_part in vae_conversion_map:
            sd_key = sd_key.replace(hf_part, sd_part)
        # The attention-specific substitutions are gated on the original key,
        # but continue from the already-substituted name.
        if "attentions" in hf_key:
            for sd_part, hf_part in vae_conversion_map_attn:
                sd_key = sd_key.replace(hf_part, sd_part)
        converted[sd_key] = vae_state_dict[hf_key]

    attn_weight_names = ["q", "k", "v", "proj_out"]
    pending_renames = {}
    # Only existing keys are reassigned inside this loop, so the dict's size
    # does not change while it is being iterated; renames are deferred.
    for sd_key, value in converted.items():
        for name in attn_weight_names:
            if f"mid.attn_1.{name}.weight" in sd_key:
                print(f"Reshaping {sd_key} for SD format")
                converted[sd_key] = reshape_weight_for_sd(value)
        for old_name, new_name in vae_extra_conversion_map:
            weight_tag = f"mid.attn_1.{old_name}.weight"
            bias_tag = f"mid.attn_1.{old_name}.bias"
            if weight_tag in sd_key or bias_tag in sd_key:
                pending_renames[sd_key] = sd_key.replace(old_name, new_name)

    for old_key, new_key in pending_renames.items():
        if old_key not in converted:
            continue
        print(f"Renaming {old_key} to {new_key}")
        # Assign first, then delete: if both names were equal the entry would
        # end up removed, which is the order the conversion has always used.
        converted[new_key] = reshape_weight_for_sd(converted[old_key])
        del converted[old_key]
    return converted
209
210
211# =========================#

Callers 1

Calls 1

reshape_weight_for_sd · Function · 0.85

Tested by

no test coverage detected