(vae_type: str)
| 716 | |
| 717 | |
def convert_vae(vae_type: str):
    """Download an original Cosmos VAE checkpoint and convert it to ``AutoencoderKLCosmos``.

    Args:
        vae_type: Key into ``VAE_CONFIGS``. The entry provides the Hub repo id
            (``"name"``) and the constructor kwargs for ``AutoencoderKLCosmos``
            (``"diffusers_config"``).

    Returns:
        An ``AutoencoderKLCosmos`` whose weights come from the TorchScript archive
        ``autoencoder.jit`` in the downloaded snapshot. Key names are renamed via
        ``VAE_KEYS_RENAME_DICT`` and ``VAE_SPECIAL_KEYS_REMAP``.

    Raises:
        RuntimeError: From ``load_state_dict`` (``strict=True``) if the remapped
            checkpoint keys do not exactly match the model's parameters.
    """
    vae_spec = VAE_CONFIGS[vae_type]
    snapshot_directory = snapshot_download(vae_spec["name"], repo_type="model")
    directory = pathlib.Path(snapshot_directory)

    autoencoder_file = directory / "autoencoder.jit"
    mean_std_file = directory / "mean_std.pt"

    original_state_dict = torch.jit.load(autoencoder_file.as_posix()).state_dict()

    # Work on a copy: updating the dict stored in VAE_CONFIGS directly would leak
    # one call's latent statistics into every later call in the same process.
    config = dict(vae_spec["diffusers_config"])

    # Latent normalization statistics are optional. The previous (None, None)
    # fallback crashed on `.detach()`, so only set the keys when the file exists
    # and otherwise keep whatever the base config specifies.
    if mean_std_file.exists():
        mean_std = torch.load(mean_std_file, map_location="cpu", weights_only=True)
        config["latents_mean"] = mean_std[0].detach().cpu().numpy().tolist()
        config["latents_std"] = mean_std[1].detach().cpu().numpy().tolist()

    vae = AutoencoderKLCosmos(**config)

    # Pass 1: plain substring renames, applied in the dict's iteration order.
    # update_state_dict_ presumably moves the entry from `key` to `new_key`
    # in place -- confirm against its definition.
    for key in list(original_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Pass 2: keys that need custom handling. Each handler receives the key and
    # the whole state dict and (per its name) edits the dict in place. A snapshot
    # of the keys is iterated because the handlers may add or remove entries.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    # assign=True makes the module adopt the loaded tensors instead of copying
    # into its freshly initialized parameters.
    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae
| 755 | |
| 756 | |
| 757 | def save_pipeline_cosmos_1_0(args, transformer, vae): |
no test coverage detected
searching dependent graphs…