()
| 858 | |
| 859 | |
def get_args():
    """Build the command-line parser for this script and parse the arguments.

    Options are declared in a single table and registered in order, so the
    generated ``--help`` output lists them in the order they appear below.

    Recognized options:
        --transformer_type: one of the keys of ``TRANSFORMER_CONFIGS``
            (defaults to ``None``).
        --transformer_ckpt_path: path to the original transformer checkpoint
            (defaults to ``None``).
        --vae_type: ``"wan2.1"`` or one of the keys of ``VAE_CONFIGS``
            (defaults to ``"wan2.1"``).
        --text_encoder_path, --tokenizer_path: optional strings
            (default to ``None``).
        --save_pipeline: boolean flag, ``False`` unless given.
        --output_path: required; path where the converted model is saved.
        --dtype: string naming the torch dtype to save the transformer in
            (defaults to ``"bf16"``); the value is not validated here.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    # (flag, keyword arguments for ``add_argument``), in registration order.
    option_specs = (
        (
            "--transformer_type",
            {"type": str, "default": None, "choices": list(TRANSFORMER_CONFIGS.keys())},
        ),
        (
            "--transformer_ckpt_path",
            {"type": str, "default": None, "help": "Path to original transformer checkpoint"},
        ),
        (
            "--vae_type",
            {
                "type": str,
                "default": "wan2.1",
                # "wan2.1" is always offered, followed by every configured VAE key.
                "choices": ["wan2.1", *list(VAE_CONFIGS.keys())],
                "help": "Type of VAE",
            },
        ),
        ("--text_encoder_path", {"type": str, "default": None}),
        ("--tokenizer_path", {"type": str, "default": None}),
        ("--save_pipeline", {"action": "store_true"}),
        (
            "--output_path",
            {"type": str, "required": True, "help": "Path where converted model should be saved"},
        ),
        ("--dtype", {"default": "bf16", "help": "Torch dtype to save the transformer in."}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
| 875 | |
| 876 | |
| 877 | DTYPE_MAPPING = { |
no outgoing calls
no test coverage detected
searching dependent graphs…