(args)
| 266 | |
| 267 | |
| 268 | def main(args): |
| 269 | original_ckpt = load_original_checkpoint(args.checkpoint_path) |
| 270 | original_dtype = next(iter(original_ckpt.values())).dtype |
| 271 | |
| 272 | # Initialize dtype with a default value |
| 273 | dtype = None |
| 274 | |
| 275 | if args.dtype is None: |
| 276 | dtype = original_dtype |
| 277 | elif args.dtype == "fp16": |
| 278 | dtype = torch.float16 |
| 279 | elif args.dtype == "bf16": |
| 280 | dtype = torch.bfloat16 |
| 281 | elif args.dtype == "fp32": |
| 282 | dtype = torch.float32 |
| 283 | else: |
| 284 | raise ValueError(f"Unsupported dtype: {args.dtype}") |
| 285 | |
| 286 | if dtype != original_dtype: |
| 287 | print( |
| 288 | f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution." |
| 289 | ) |
| 290 | |
| 291 | num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401 |
| 292 | |
| 293 | caption_projection_dim = get_caption_projection_dim(original_ckpt) |
| 294 | |
| 295 | # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5 |
| 296 | attn2_layers = get_attn2_layers(original_ckpt) |
| 297 | |
| 298 | # sd3.5 uses qk norm ("rms_norm") |
| 299 | has_qk_norm = any("ln_q" in key for key in original_ckpt.keys()) |
| 300 | |
| 301 | # sd3.5 2b uses pos_embed_max_size=384; sd3.0 and sd3.5 8b use 192 |
| 302 | pos_embed_max_size = get_pos_embed_max_size(original_ckpt) |
| 303 | |
| 304 | converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers( |
| 305 | original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm |
| 306 | ) |
| 307 | |
| 308 | with CTX(): |
| 309 | transformer = SD3Transformer2DModel( |
| 310 | sample_size=128, |
| 311 | patch_size=2, |
| 312 | in_channels=16, |
| 313 | joint_attention_dim=4096, |
| 314 | num_layers=num_layers, |
| 315 | caption_projection_dim=caption_projection_dim, |
| 316 | num_attention_heads=num_layers, |
| 317 | pos_embed_max_size=pos_embed_max_size, |
| 318 | qk_norm="rms_norm" if has_qk_norm else None, |
| 319 | dual_attention_layers=attn2_layers, |
| 320 | ) |
| 321 | if is_accelerate_available(): |
| 322 | load_model_dict_into_meta(transformer, converted_transformer_state_dict) |
| 323 | else: |
| 324 | transformer.load_state_dict(converted_transformer_state_dict, strict=True) |
| 325 |
no test coverage detected
searching dependent graphs…