(original_state_dict: dict[str, Any], version: str)
| 801 | |
| 802 | |
def convert_ltx2_vocoder(
    original_state_dict: dict[str, Any], version: str
) -> "LTX2Vocoder | LTX2VocoderWithBWE":
    """Convert an original LTX-2 vocoder checkpoint into a diffusers vocoder model.

    The conversion runs two passes over ``original_state_dict`` before loading it:

    1. Every key is rewritten by applying each ``rename_dict`` substring replacement
       in the dict's iteration order (later replacements see the output of earlier
       ones), then moved via ``update_state_dict_inplace``.
    2. For every key that contains one of the ``special_keys_remap`` substrings, the
       matching handler is called as ``handler(key, state_dict)`` so it can rewrite
       the state dict in place for cases a simple 1:1 rename cannot express.

    Args:
        original_state_dict: Checkpoint entries keyed by their original names.
            NOTE: this dict is mutated in place by both passes, and its values are
            adopted by the returned model (``assign=True``) rather than copied.
        version: LTX-2 release identifier, forwarded to ``get_ltx2_vocoder_config``.
            ``"2.3"`` selects ``LTX2VocoderWithBWE``; any other value selects
            ``LTX2Vocoder``.

    Returns:
        The instantiated vocoder (``LTX2VocoderWithBWE`` or ``LTX2Vocoder``) with the
        converted weights loaded.

    Raises:
        Whatever ``vocoder.load_state_dict(..., strict=True)`` raises when the
        converted keys do not exactly match the model's expected keys (for a
        ``torch.nn.Module`` this is ``RuntimeError`` — presumably the case here;
        confirm against the vocoder classes).
    """
    config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
    diffusers_config = config["diffusers_config"]

    # Only release 2.3 ships the bandwidth-extension variant of the vocoder.
    vocoder_cls = LTX2VocoderWithBWE if version == "2.3" else LTX2Vocoder

    # NOTE(review): presumably accelerate's `init_empty_weights`, which builds the
    # model without allocating real parameter storage; that is why the final load
    # uses `assign=True` to adopt the checkpoint tensors directly.
    with init_empty_weights():
        vocoder = vocoder_cls.from_config(diffusers_config)

    # Pass 1: official key names -> diffusers key names via plain substring renames.
    # Snapshot the keys first because the dict is modified while we walk it.
    for key in list(original_state_dict):
        new_key = key
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Pass 2: special cases that a 1:1 rename cannot express. Every handler whose
    # trigger substring appears in the key is invoked, in `special_keys_remap` order.
    for key in list(original_state_dict):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key in key:
                handler_fn_inplace(key, original_state_dict)

    # strict=True: any leftover or missing key is an error rather than silently ignored.
    vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
    return vocoder
| 831 | |
| 832 | |
| 833 | def get_ltx2_spatial_latent_upsampler_config(version: str): |
no test coverage detected
searching dependent graphs…