MCPcopy Index your code
hub / github.com/huggingface/diffusers / main

Function main

scripts/convert_mochi_to_diffusers.py:411–459  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

409
410
def main(args):
    """Convert original Mochi checkpoints into a diffusers ``MochiPipeline`` and save it.

    The transformer is converted only when ``transformer_checkpoint_path`` is given, and the
    VAE only when both the encoder and decoder checkpoint paths are given. A component that
    is not converted is passed to the pipeline as ``None``. The T5 text encoder and tokenizer
    are always loaded from ``google/t5-v1_1-xxl``.

    Args:
        args: Parsed command-line arguments. Attributes read here: ``dtype`` (``None``,
            ``"fp16"``, ``"bf16"`` or ``"fp32"``), ``transformer_checkpoint_path``,
            ``vae_encoder_checkpoint_path``, ``vae_decoder_checkpoint_path``,
            ``text_encoder_cache_dir``, ``output_path`` and ``push_to_hub``.

    Raises:
        ValueError: If ``args.dtype`` is neither ``None`` nor one of the supported names.
    """
    import warnings

    # Check ``None`` first and chain the named dtypes with ``elif`` so that omitting the
    # dtype keeps the converted models in their construction dtype instead of being rejected.
    supported_dtypes = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    if args.dtype is None:
        dtype = None
    elif args.dtype in supported_dtypes:
        dtype = supported_dtypes[args.dtype]
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    transformer = None
    vae = None

    if args.transformer_checkpoint_path is not None:
        converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
            args.transformer_checkpoint_path
        )
        transformer = MochiTransformer3DModel()
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)
        if dtype is not None:
            transformer = transformer.to(dtype=dtype)

    has_vae_encoder = args.vae_encoder_checkpoint_path is not None
    has_vae_decoder = args.vae_decoder_checkpoint_path is not None
    if has_vae_encoder and has_vae_decoder:
        vae = AutoencoderKLMochi(latent_channels=12, out_channels=3)
        converted_vae_state_dict = convert_mochi_vae_state_dict_to_diffusers(
            args.vae_encoder_checkpoint_path, args.vae_decoder_checkpoint_path
        )
        vae.load_state_dict(converted_vae_state_dict, strict=True)
        if dtype is not None:
            vae = vae.to(dtype=dtype)
    elif has_vae_encoder or has_vae_decoder:
        # The VAE converter needs both halves; tell the user instead of silently dropping it.
        warnings.warn(
            "Both vae_encoder_checkpoint_path and vae_decoder_checkpoint_path are required "
            "to convert the VAE; skipping VAE conversion."
        )

    text_encoder_id = "google/t5-v1_1-xxl"
    # Tokenizer and encoder come from the same repo, so share the same cache directory.
    tokenizer = T5Tokenizer.from_pretrained(
        text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH, cache_dir=args.text_encoder_cache_dir
    )
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    # Make every text-encoder parameter contiguous before saving. Saving was observed to fail
    # without this; presumably because safetensors serialization (safe_serialization=True
    # below) rejects non-contiguous tensors -- TODO confirm.
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

    pipe = MochiPipeline(
        scheduler=FlowMatchEulerDiscreteScheduler(invert_sigmas=True),
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
    )
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
460
461
462if __name__ == "__main__":

Callers 1

Calls 11

AutoencoderKLMochiClass · 0.90
MochiPipelineClass · 0.90
parametersMethod · 0.80
load_state_dictMethod · 0.45
toMethod · 0.45
from_pretrainedMethod · 0.45
save_pretrainedMethod · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…