MCPcopy
hub / github.com/zai-org/CogVideo / convert_transformer

Function convert_transformer

tools/convert_weight_sat2hf.py:155–186  ·  view source on GitHub ↗
(
    ckpt_path: str,
    num_layers: int,
    num_attention_heads: int,
    use_rotary_positional_embeddings: bool,
    i2v: bool,
    dtype: torch.dtype,
)

Source from the content-addressed store, hash-verified

153
154
def convert_transformer(
    ckpt_path: str,
    num_layers: int,
    num_attention_heads: int,
    use_rotary_positional_embeddings: bool,
    i2v: bool,
    dtype: torch.dtype,
):
    """Build a ``CogVideoXTransformer3DModel`` and load a checkpoint into it.

    The checkpoint is read onto the CPU, unwrapped with ``get_state_dict``,
    and its keys are rewritten in two passes (bulk rename, then special-case
    handlers) before being loaded into the model with ``strict=True``.

    Args:
        ckpt_path: Path of the checkpoint file handed to ``torch.load``.
        num_layers: Forwarded to the model constructor.
        num_attention_heads: Forwarded to the model constructor.
        use_rotary_positional_embeddings: Forwarded to the model constructor.
        i2v: When true the model is built with 32 input channels and learned
            positional embeddings enabled; otherwise 16 input channels and
            learned positional embeddings disabled.
        dtype: The freshly built model is cast to this dtype before the
            weights are loaded.

    Returns:
        The model with the converted weights loaded.
    """
    PREFIX_KEY = "model.diffusion_model."
    prefix_length = len(PREFIX_KEY)

    # NOTE(review): torch.load is called without ``weights_only``; confirm
    # that checkpoints passed here always come from a trusted source.
    state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))

    if i2v:
        in_channels = 32
    else:
        in_channels = 16
    model = CogVideoXTransformer3DModel(
        in_channels=in_channels,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
        use_learned_positional_embeddings=i2v,
    )
    model = model.to(dtype=dtype)

    # Pass 1: bulk rename. Iterate over a snapshot of the keys because the
    # dictionary is modified while we walk it.
    for old_name in list(state_dict):
        # The first ``prefix_length`` characters are dropped unconditionally,
        # whether or not the key actually starts with PREFIX_KEY.
        renamed = old_name[prefix_length:]
        for pattern, replacement in TRANSFORMER_KEYS_RENAME_DICT.items():
            renamed = renamed.replace(pattern, replacement)
        # presumably re-files the entry under ``renamed`` -- verify against
        # update_state_dict_inplace.
        update_state_dict_inplace(state_dict, old_name, renamed)

    # Pass 2: every special-case handler whose marker occurs in a key is
    # invoked on that key; handlers receive the whole dictionary.
    for name in list(state_dict):
        for marker, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if marker in name:
                handler_fn_inplace(name, state_dict)

    model.load_state_dict(state_dict, strict=True)
    return model
187
188
189def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):

Callers 1

Calls 4

loadMethod · 0.80
load_state_dictMethod · 0.80
get_state_dictFunction · 0.70

Tested by

no test coverage detected