MCPcopy
hub / github.com/huggingface/diffusers / convert_vae

Function convert_vae

scripts/convert_cosmos_to_diffusers.py:718–754  ·  view source on GitHub ↗
(vae_type: str)

Source from the content-addressed store, hash-verified

716
717
def convert_vae(vae_type: str):
    """Download an original Cosmos VAE checkpoint and convert it to a diffusers ``AutoencoderKLCosmos``.

    The Hub repo is ``VAE_CONFIGS[vae_type]["name"]``. Its TorchScript ``autoencoder.jit`` provides the
    weights. An optional ``mean_std.pt`` provides latent statistics, which are written into the model
    config as ``latents_mean`` / ``latents_std``.

    Args:
        vae_type: Key into ``VAE_CONFIGS`` selecting the checkpoint repo and its diffusers config.

    Returns:
        An ``AutoencoderKLCosmos`` built from the (copied) diffusers config, with the renamed original
        weights loaded via ``load_state_dict(strict=True, assign=True)``.
    """
    model_name = VAE_CONFIGS[vae_type]["name"]
    snapshot_directory = snapshot_download(model_name, repo_type="model")
    directory = pathlib.Path(snapshot_directory)

    autoencoder_file = directory / "autoencoder.jit"
    mean_std_file = directory / "mean_std.pt"

    original_state_dict = torch.jit.load(autoencoder_file.as_posix()).state_dict()

    # Work on a copy: updating the module-level VAE_CONFIGS entry in place would leak these
    # statistics into any later call for the same vae_type (e.g. one whose snapshot lacks mean_std.pt).
    config = dict(VAE_CONFIGS[vae_type]["diffusers_config"])
    if mean_std_file.exists():
        mean_std = torch.load(mean_std_file, map_location="cpu", weights_only=True)
        config.update(
            {
                "latents_mean": mean_std[0].detach().cpu().numpy().tolist(),
                "latents_std": mean_std[1].detach().cpu().numpy().tolist(),
            }
        )
    # Without mean_std.pt there are no statistics to record. The previous (None, None) fallback
    # crashed on ``None.detach()``. Here latents_mean/latents_std are simply not overridden, so they
    # come from the config or AutoencoderKLCosmos defaults (TODO confirm those suit this checkpoint).
    vae = AutoencoderKLCosmos(**config)

    # Map original parameter names to diffusers names. Every VAE_KEYS_RENAME_DICT substring
    # replacement is applied in dict order. update_state_dict_ presumably re-keys the entry from
    # `key` to `new_key` (verify against its definition). Keys are snapshotted because the dict
    # changes while we iterate.
    for key in list(original_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Keys needing more than a plain rename are passed to their matching handler, which may mutate
    # original_state_dict (hence *_inplace and the key snapshot). A key can match several handlers.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key in key:
                handler_fn_inplace(key, original_state_dict)

    # strict=True rejects any missing/unexpected key after conversion. assign=True keeps the loaded
    # tensors' properties instead of copying into the freshly initialized parameters.
    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae
755
756
757def save_pipeline_cosmos_1_0(args, transformer, vae):

Callers 1

Calls 7

AutoencoderKLCosmos · Class · 0.90
exists · Method · 0.80
update_state_dict_ · Function · 0.70
state_dict · Method · 0.45
load · Method · 0.45
update · Method · 0.45
load_state_dict · Method · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…