MCPcopy Index your code
hub / github.com/huggingface/diffusers / main

Function main

scripts/convert_cogview3_to_diffusers.py:158–238  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

156
157
158def main(args):
159 if args.dtype == "fp16":
160 dtype = torch.float16
161 elif args.dtype == "bf16":
162 dtype = torch.bfloat16
163 elif args.dtype == "fp32":
164 dtype = torch.float32
165 else:
166 raise ValueError(f"Unsupported dtype: {args.dtype}")
167
168 transformer = None
169 vae = None
170
171 if args.transformer_checkpoint_path is not None:
172 converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(
173 args.transformer_checkpoint_path
174 )
175 transformer = CogView3PlusTransformer2DModel()
176 transformer.load_state_dict(converted_transformer_state_dict, strict=True)
177 if dtype is not None:
178 # Original checkpoint data type will be preserved
179 transformer = transformer.to(dtype=dtype)
180
181 if args.vae_checkpoint_path is not None:
182 vae_config = {
183 "in_channels": 3,
184 "out_channels": 3,
185 "down_block_types": ("DownEncoderBlock2D",) * 4,
186 "up_block_types": ("UpDecoderBlock2D",) * 4,
187 "block_out_channels": (128, 512, 1024, 1024),
188 "layers_per_block": 3,
189 "act_fn": "silu",
190 "latent_channels": 16,
191 "norm_num_groups": 32,
192 "sample_size": 1024,
193 "scaling_factor": 1.0,
194 "force_upcast": True,
195 "use_quant_conv": False,
196 "use_post_quant_conv": False,
197 "mid_block_add_attention": False,
198 }
199 converted_vae_state_dict = convert_cogview3_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
200 vae = AutoencoderKL(**vae_config)
201 vae.load_state_dict(converted_vae_state_dict, strict=True)
202 if dtype is not None:
203 vae = vae.to(dtype=dtype)
204
205 text_encoder_id = "google/t5-v1_1-xxl"
206 tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
207 text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
208
209 # Apparently, the conversion does not work anymore without this :shrug:
210 for param in text_encoder.parameters():
211 param.data = param.data.contiguous()
212
213 scheduler = CogVideoXDDIMScheduler.from_config(
214 {
215 "snr_shift_scale": 4.0,

Calls 11

AutoencoderKL · Class · 0.90
parameters · Method · 0.80
load_state_dict · Method · 0.45
to · Method · 0.45
from_pretrained · Method · 0.45
from_config · Method · 0.45
save_pretrained · Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…