MCPcopy Index your code
hub / github.com/huggingface/diffusers / parse_args

Function parse_args

examples/discrete_diffusion/train_llada2.py:68–102  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

66
67
def parse_args() -> TrainConfig:
    """Parse the training script's command-line flags into a ``TrainConfig``.

    Each flag declared below is handed to ``TrainConfig`` as a keyword
    argument named after the flag (leading dashes removed), so every flag
    name here has to be a keyword that ``TrainConfig`` accepts.

    Returns:
        The ``TrainConfig`` built from the parsed command-line values.
    """
    # Flags are declared as data and registered in this exact order, which is
    # also the order they show up in ``--help``.
    option_table = (
        # Model and dataset selection.
        ("--model_name_or_path", {"type": str, "default": "Qwen/Qwen2.5-0.5B"}),
        ("--dataset_name", {"type": str, "default": "wikitext"}),
        ("--dataset_config_name", {"type": str, "default": "wikitext-2-raw-v1"}),
        ("--text_column", {"type": str, "default": "text"}),
        ("--cache_dir", {"type": str, "default": None}),
        (
            "--use_dummy_data",
            {"action": "store_true", "help": "Use random-token data instead of downloading."},
        ),
        ("--num_dummy_samples", {"type": int, "default": 2048}),
        # Output location, seed, and step counters.
        ("--output_dir", {"type": str, "default": "block-refinement-output"}),
        ("--seed", {"type": int, "default": 0}),
        ("--max_train_steps", {"type": int, "default": 1000}),
        ("--checkpointing_steps", {"type": int, "default": 500}),
        ("--logging_steps", {"type": int, "default": 50}),
        # Batch size, optimizer, and learning-rate schedule.
        ("--per_device_train_batch_size", {"type": int, "default": 1}),
        ("--gradient_accumulation_steps", {"type": int, "default": 8}),
        ("--learning_rate", {"type": float, "default": 2e-5}),
        ("--weight_decay", {"type": float, "default": 0.0}),
        (
            "--lr_scheduler",
            {
                "type": str,
                "default": "cosine",
                "choices": ["linear", "cosine", "cosine_with_restarts"],
            },
        ),
        ("--lr_warmup_steps", {"type": int, "default": 100}),
        # Length settings.
        ("--max_length", {"type": int, "default": 256}),
        ("--prompt_length", {"type": int, "default": 32}),
        ("--block_length", {"type": int, "default": 32}),
        # Confidence-related settings.
        ("--lambda_conf", {"type": float, "default": 2.0}),
        ("--conf_temperature", {"type": float, "default": 0.5}),
    )

    cli = argparse.ArgumentParser(
        description="Train block-refinement with a confidence-aware loss on a causal LM."
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    # With no explicit argument list, argparse reads the process's sys.argv.
    namespace = cli.parse_args()
    return TrainConfig(**vars(namespace))
103
104
105def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int):

Callers 1

main · Function · 0.70

Calls 1

TrainConfig · Class · 0.85

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…