()
| 66 | |
| 67 | |
| 68 | def parse_args() -> TrainConfig: |
| 69 | parser = argparse.ArgumentParser(description="Train block-refinement with a confidence-aware loss on a causal LM.") |
| 70 | |
| 71 | parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2.5-0.5B") |
| 72 | parser.add_argument("--dataset_name", type=str, default="wikitext") |
| 73 | parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1") |
| 74 | parser.add_argument("--text_column", type=str, default="text") |
| 75 | parser.add_argument("--cache_dir", type=str, default=None) |
| 76 | parser.add_argument("--use_dummy_data", action="store_true", help="Use random-token data instead of downloading.") |
| 77 | parser.add_argument("--num_dummy_samples", type=int, default=2048) |
| 78 | |
| 79 | parser.add_argument("--output_dir", type=str, default="block-refinement-output") |
| 80 | parser.add_argument("--seed", type=int, default=0) |
| 81 | parser.add_argument("--max_train_steps", type=int, default=1000) |
| 82 | parser.add_argument("--checkpointing_steps", type=int, default=500) |
| 83 | parser.add_argument("--logging_steps", type=int, default=50) |
| 84 | |
| 85 | parser.add_argument("--per_device_train_batch_size", type=int, default=1) |
| 86 | parser.add_argument("--gradient_accumulation_steps", type=int, default=8) |
| 87 | parser.add_argument("--learning_rate", type=float, default=2e-5) |
| 88 | parser.add_argument("--weight_decay", type=float, default=0.0) |
| 89 | parser.add_argument( |
| 90 | "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"] |
| 91 | ) |
| 92 | parser.add_argument("--lr_warmup_steps", type=int, default=100) |
| 93 | |
| 94 | parser.add_argument("--max_length", type=int, default=256) |
| 95 | parser.add_argument("--prompt_length", type=int, default=32) |
| 96 | parser.add_argument("--block_length", type=int, default=32) |
| 97 | |
| 98 | parser.add_argument("--lambda_conf", type=float, default=2.0) |
| 99 | parser.add_argument("--conf_temperature", type=float, default=0.5) |
| 100 | |
| 101 | args = parser.parse_args() |
| 102 | return TrainConfig(**vars(args)) |
| 103 | |
| 104 | |
| 105 | def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int): |
no test coverage detected
searching dependent graphs…