| 15 | from utils.utils import set_random_seed |
| 16 | |
| 17 | def parse_args(): |
| 18 | parser = argparse.ArgumentParser( |
| 19 | description= |
| 20 | "Finetune a transformers model on a causal language modeling task") |
| 21 | parser.add_argument('--data_path', |
| 22 | type=str, |
| 23 | required=True, |
| 24 | help='A json file store dataset path and weight') |
| 25 | parser.add_argument( |
| 26 | '--data_output_path', |
| 27 | type=str, |
| 28 | default='/tmp/data_files/', |
| 29 | help='Where to save the processed data.' |
| 30 | ) |
| 31 | parser.add_argument( |
| 32 | "--tokenizer_path", |
| 33 | type=str, |
| 34 | help= |
| 35 | "Path to the tokenizer", |
| 36 | required=True, |
| 37 | ) |
| 38 | parser.add_argument( |
| 39 | "--max_seq_len", |
| 40 | type=int, |
| 41 | default=512, |
| 42 | help="The maximum sequence length.", |
| 43 | ) |
| 44 | parser.add_argument("--seed", |
| 45 | type=int, |
| 46 | default=1234, |
| 47 | help="A seed for reproducible training.") |
| 48 | parser.add_argument("--user_token", |
| 49 | type=str, |
| 50 | default="<_user>", |
| 51 | help="user token") |
| 52 | parser.add_argument("--bot_token", |
| 53 | type=str, |
| 54 | default="<_bot>", |
| 55 | help="bot token") |
| 56 | parser.add_argument("--end_token", |
| 57 | type=str, |
| 58 | default="<_end>", |
| 59 | help="end token") |
| 60 | parser.add_argument("--num_workers", |
| 61 | type=int, |
| 62 | default=5, |
| 63 | help="Number of workers when tokenizing dataset") |
| 64 | parser.add_argument("--num_samples", |
| 65 | type=int, |
| 66 | required=True, |
| 67 | help="Number of samples while training") |
| 68 | parser.add_argument('--process_method', |
| 69 | choices=['single', 'multiple'], |
| 70 | required=True, |
| 71 | help='Choose the method (multiple process or single process) while processing dataset, note that' |
| 72 | 'when using both multi-process and multi-nodes, you should have a shared system.') |
| 73 | parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true") |
| 74 | args = parser.parse_args() |