(
ckpt_path: str,
num_layers: int,
num_attention_heads: int,
use_rotary_positional_embeddings: bool,
i2v: bool,
dtype: torch.dtype,
)
| 153 | |
| 154 | |
def convert_transformer(
    ckpt_path: str,
    num_layers: int,
    num_attention_heads: int,
    use_rotary_positional_embeddings: bool,
    i2v: bool,
    dtype: torch.dtype,
):
    """Build a ``CogVideoXTransformer3DModel`` and load a converted checkpoint into it.

    The checkpoint's state dict is rewritten in place: the
    ``"model.diffusion_model."`` prefix is stripped from each key that carries
    it, the substitutions in ``TRANSFORMER_KEYS_RENAME_DICT`` are applied, and
    the handlers in ``TRANSFORMER_SPECIAL_KEYS_REMAP`` are run. The result is
    then loaded with ``strict=True``.

    Args:
        ckpt_path: Path handed to ``torch.load``.
        num_layers: Forwarded to the model constructor.
        num_attention_heads: Forwarded to the model constructor.
        use_rotary_positional_embeddings: Forwarded to the model constructor.
        i2v: When true the model is built with ``in_channels=32`` and
            ``use_learned_positional_embeddings=True``; otherwise with
            ``in_channels=16`` and ``use_learned_positional_embeddings=False``.
        dtype: dtype the freshly constructed model is cast to before loading.

    Returns:
        The model instance after ``load_state_dict`` has been called on it.
    """
    PREFIX_KEY = "model.diffusion_model."

    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be passed here. Consider weights_only=True if it turns out to be
    # compatible with these checkpoints -- TODO confirm.
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
    transformer = CogVideoXTransformer3DModel(
        in_channels=32 if i2v else 16,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
        use_learned_positional_embeddings=i2v,
    ).to(dtype=dtype)

    # Snapshot the keys first: the dict is mutated while we walk over it.
    for key in list(original_state_dict.keys()):
        # Only strip the prefix when it is actually there. Slicing
        # unconditionally would chop the first len(PREFIX_KEY) characters off
        # any key that lacks it and silently produce a corrupted name.
        new_key = key[len(PREFIX_KEY) :] if key.startswith(PREFIX_KEY) else key
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Second pass over the renamed keys: every handler whose special key occurs
    # as a substring of the key is invoked with the key and the whole dict.
    # Presumably the handlers mutate the dict in place -- verify against their
    # definitions.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True)
    return transformer
| 187 | |
| 188 | |
| 189 | def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype): |
no test coverage detected