Repository: github.com/huggingface/diffusers (branch: main)

Function main(args)

scripts/convert_sd3_to_diffusers.py, lines 268–347


```python
# torch, CTX, SD3Transformer2DModel, load_model_dict_into_meta, is_accelerate_available and the
# get_*/convert_* helpers are module-level names defined or imported elsewhere in this script.
def main(args):
    original_ckpt = load_original_checkpoint(args.checkpoint_path)
    original_dtype = next(iter(original_ckpt.values())).dtype

    # Initialize dtype with a default value
    dtype = None

    if args.dtype is None:
        dtype = original_dtype
    elif args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    if dtype != original_dtype:
        print(
            f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution."
        )

    # Depth = highest "joint_blocks.<i>." index + 1; max() does not rely on set iteration order
    num_layers = max(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k) + 1

    caption_projection_dim = get_caption_projection_dim(original_ckpt)

    # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
    attn2_layers = get_attn2_layers(original_ckpt)

    # sd3.5 uses QK normalization ("rms_norm")
    has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())

    # sd3.5 2b uses pos_embed_max_size=384; sd3.0 and sd3.5 8b use 192
    pos_embed_max_size = get_pos_embed_max_size(original_ckpt)

    converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
        original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm
    )

    with CTX():
        transformer = SD3Transformer2DModel(
            sample_size=128,
            patch_size=2,
            in_channels=16,
            joint_attention_dim=4096,
            num_layers=num_layers,
            caption_projection_dim=caption_projection_dim,
            num_attention_heads=num_layers,  # intentional: the SD3 MMDiT uses one attention head per block
            pos_embed_max_size=pos_embed_max_size,
            qk_norm="rms_norm" if has_qk_norm else None,
            dual_attention_layers=attn2_layers,
        )
    if is_accelerate_available():
        load_model_dict_into_meta(transformer, converted_transformer_state_dict)
    else:
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)

    # ... remainder of main (lines 326–347) is not shown in this excerpt
```
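main infers the transformer depth from the checkpoint keys alone: every `joint_blocks.<i>.` prefix belongs to one MMDiT block, so the depth is the highest index plus one. The sketch below illustrates that step on a hypothetical toy state dict (only key names matter, so values are `None`; the depth of 24 is purely illustrative). `infer_num_layers` is a hypothetical name, and unlike the script it matches keys with `startswith("joint_blocks.")` rather than a substring test, which keeps the `split(".", 2)[1]` index extraction safe.

```python
def infer_num_layers(state_dict: dict) -> int:
    """Return 1 + the highest block index among keys shaped like 'joint_blocks.<i>.<...>'."""
    indices = {
        int(key.split(".", 2)[1])  # 'joint_blocks.12.x_block.attn.qkv.weight' -> 12
        for key in state_dict
        if key.startswith("joint_blocks.")
    }
    if not indices:
        raise ValueError("no 'joint_blocks.<i>.' keys found in checkpoint")
    # max() is order-independent; indexing a list built from a set is not guaranteed to be.
    return max(indices) + 1


# Hypothetical toy checkpoint with 24 blocks plus one unrelated key.
toy_ckpt = {f"joint_blocks.{i}.x_block.attn.qkv.weight": None for i in range(24)}
toy_ckpt["x_embedder.proj.weight"] = None  # non-block key, ignored
print(infer_num_layers(toy_ckpt))  # 24
```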

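The remaining architecture flags are read off the checkpoint in the same way. The sketch below is a hypothetical stand-in for `get_attn2_layers`, the `has_qk_norm` test and `get_pos_embed_max_size`; it is not the script's implementation. It assumes that dual-attention blocks carry an `.attn2.` segment in their key names and that `pos_embed` stores a square grid of `max_size × max_size` positions, so the grid side is the integer square root of the position count. Key names are illustrative, and shapes are plain tuples so the example runs without torch.

```python
import math


def infer_attn2_layers(keys) -> tuple:
    """Sorted indices of blocks whose keys contain an '.attn2.' segment (assumed naming)."""
    return tuple(sorted({
        int(k.split(".")[1])  # 'joint_blocks.3.x_block.attn2.qkv.weight' -> 3
        for k in keys
        if k.startswith("joint_blocks.") and ".attn2." in k
    }))


def infer_has_qk_norm(keys) -> bool:
    # Same test as main(): any key mentioning 'ln_q' implies QK normalization.
    return any("ln_q" in k for k in keys)


def infer_pos_embed_max_size(pos_embed_shape) -> int:
    """Assumes pos_embed has shape (1, max_size * max_size, hidden_dim)."""
    num_positions = pos_embed_shape[1]
    side = math.isqrt(num_positions)
    if side * side != num_positions:
        raise ValueError(f"{num_positions} positions do not form a square grid")
    return side


keys = [
    "joint_blocks.0.x_block.attn.ln_q.weight",
    "joint_blocks.0.x_block.attn2.qkv.weight",
    "joint_blocks.1.x_block.attn2.qkv.weight",
    "joint_blocks.2.x_block.attn.qkv.weight",
]
print(infer_attn2_layers(keys))                      # (0, 1)
print(infer_has_qk_norm(keys))                       # True
print(infer_pos_embed_max_size((1, 384 * 384, 64)))  # 384
print(infer_pos_embed_max_size((1, 192 * 192, 64)))  # 192
```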
Callers: 1

Calls: 15 (7 listed)

- is_accelerate_available (function, 0.90)
- get_pos_embed_max_size (function, 0.85)
- is_vae_in_checkpoint (function, 0.85)
- split (method, 0.80)
- load_original_checkpoint (function, 0.70)
- get_attn2_layers (function, 0.70)
- load_state_dict (method, 0.45)
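Two of these callees, `is_accelerate_available` and `load_state_dict`, pick the loading path in main: with accelerate the model is built under `CTX()` and filled via `load_model_dict_into_meta`, otherwise it is loaded with `load_state_dict(strict=True)`. `CTX` is defined outside this excerpt; the sketch below assumes it follows the common optional-dependency pattern of using accelerate's `init_empty_weights` when the package is installed and `contextlib.nullcontext` otherwise. `accelerate_available` and `pick_init_context` are hypothetical names.

```python
import importlib.util
from contextlib import nullcontext


def accelerate_available() -> bool:
    # True if the 'accelerate' package is importable; find_spec does not import it.
    return importlib.util.find_spec("accelerate") is not None


def pick_init_context():
    """Return a zero-argument context-manager factory for model construction."""
    if accelerate_available():
        # Parameters are created on the meta device, so no real weight memory is allocated.
        from accelerate import init_empty_weights
        return init_empty_weights
    # Fallback: a no-op context; parameters are allocated normally.
    return nullcontext


CTX = pick_init_context()
with CTX():
    pass  # build the model here; weights are loaded afterwards
print(CTX.__name__)
```

Under meta initialization the weights have no storage, which is why main pairs that path with a loader that materializes tensors instead of a plain `load_state_dict`.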

Tested by: no test coverage detected
