| 28 | ) |
| 29 | |
def parse_args(args=None):
    """Parse command-line options for pretraining the DeTikZify projection layer.

    Args:
        args: Optional sequence of argument strings to parse. When ``None``
            (the default), ``argparse`` falls back to ``sys.argv[1:]``, which
            preserves the original command-line behaviour. Passing an
            explicit list allows the parser to be driven programmatically,
            e.g. from tests or another script.

    Returns:
        An ``argparse.Namespace`` with the attributes ``base_model``,
        ``size``, ``output``, ``deepspeed`` and ``gradient_checkpointing``.

    Raises:
        SystemExit: raised by ``argparse`` when a required option
            (``--base_model`` or ``--output``) is missing or a value fails
            type conversion (e.g. a non-integer ``--size``).
    """
    argument_parser = ArgumentParser(
        description="Pretrain projection layer of DeTikZify."
    )
    argument_parser.add_argument(
        "--base_model",
        required=True,
        help="The model checkpoint for weights initialization.",
    )
    argument_parser.add_argument(
        "--size",
        default=1_000_000,
        type=int,
        help="the amount of figures to use for pretraining",
    )
    argument_parser.add_argument(
        "--output",
        required=True,
        help="directory where to write the model files",
    )
    # No default is given, so this is None unless a config path is supplied.
    argument_parser.add_argument(
        "--deepspeed",
        help="path to a DeepSpeed json config file",
    )
    # store_true flag: False unless passed on the command line.
    argument_parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="use gradient checkpointing",
    )

    # args=None makes argparse read sys.argv[1:], matching the old behaviour.
    return argument_parser.parse_args(args)
| 56 | |
| 57 | if __name__ == "__main__": |
| 58 | set_verbosity_info() |