(args)
| 156 | |
| 157 | |
| 158 | def main(args): |
| 159 | if args.dtype == "fp16": |
| 160 | dtype = torch.float16 |
| 161 | elif args.dtype == "bf16": |
| 162 | dtype = torch.bfloat16 |
| 163 | elif args.dtype == "fp32": |
| 164 | dtype = torch.float32 |
| 165 | else: |
| 166 | raise ValueError(f"Unsupported dtype: {args.dtype}") |
| 167 | |
| 168 | transformer = None |
| 169 | vae = None |
| 170 | |
| 171 | if args.transformer_checkpoint_path is not None: |
| 172 | converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers( |
| 173 | args.transformer_checkpoint_path |
| 174 | ) |
| 175 | transformer = CogView3PlusTransformer2DModel() |
| 176 | transformer.load_state_dict(converted_transformer_state_dict, strict=True) |
| 177 | if dtype is not None: |
| 178 | # Original checkpoint data type will be preserved |
| 179 | transformer = transformer.to(dtype=dtype) |
| 180 | |
| 181 | if args.vae_checkpoint_path is not None: |
| 182 | vae_config = { |
| 183 | "in_channels": 3, |
| 184 | "out_channels": 3, |
| 185 | "down_block_types": ("DownEncoderBlock2D",) * 4, |
| 186 | "up_block_types": ("UpDecoderBlock2D",) * 4, |
| 187 | "block_out_channels": (128, 512, 1024, 1024), |
| 188 | "layers_per_block": 3, |
| 189 | "act_fn": "silu", |
| 190 | "latent_channels": 16, |
| 191 | "norm_num_groups": 32, |
| 192 | "sample_size": 1024, |
| 193 | "scaling_factor": 1.0, |
| 194 | "force_upcast": True, |
| 195 | "use_quant_conv": False, |
| 196 | "use_post_quant_conv": False, |
| 197 | "mid_block_add_attention": False, |
| 198 | } |
| 199 | converted_vae_state_dict = convert_cogview3_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config) |
| 200 | vae = AutoencoderKL(**vae_config) |
| 201 | vae.load_state_dict(converted_vae_state_dict, strict=True) |
| 202 | if dtype is not None: |
| 203 | vae = vae.to(dtype=dtype) |
| 204 | |
| 205 | text_encoder_id = "google/t5-v1_1-xxl" |
| 206 | tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) |
| 207 | text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) |
| 208 | |
| 209 | # Apparently, the conversion does not work anymore without this :shrug: |
| 210 | for param in text_encoder.parameters(): |
| 211 | param.data = param.data.contiguous() |
| 212 | |
| 213 | scheduler = CogVideoXDDIMScheduler.from_config( |
| 214 | { |
| 215 | "snr_shift_scale": 4.0, |
no test coverage detected
searching dependent graphs…