()
| 634 | |
| 635 | |
| 636 | def main(): |
| 637 | # See all possible arguments by passing the --help flag to this script. |
| 638 | parser = HfArgumentParser( |
| 639 | (ModelArguments, DataTrainingArguments, TrainingArguments) |
| 640 | ) |
| 641 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): |
| 642 | # If we pass only one argument to the script and it's the path to a json file, |
| 643 | # let's parse it to get our arguments. |
| 644 | model_args, data_args, training_args = parser.parse_json_file( |
| 645 | json_file=os.path.abspath(sys.argv[1]) |
| 646 | ) |
| 647 | else: |
| 648 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
| 649 | |
| 650 | # check arguments |
| 651 | if training_args.mp_devices > jax.local_device_count(): |
| 652 | assert ( |
| 653 | data_args.seed_dataset is not None |
| 654 | ), "Seed dataset must be provided when model is split over multiple hosts" |
| 655 | |
| 656 | # Make one log on every process with the configuration for debugging. |
| 657 | logging.basicConfig( |
| 658 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| 659 | datefmt="%m/%d/%Y %H:%M:%S", |
| 660 | level=logging.INFO, |
| 661 | ) |
| 662 | # Set up logging: only the process with index 0 logs at INFO level; all other processes log errors only. |
| 663 | logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) |
| 664 | if jax.process_index() == 0: |
| 665 | datasets.utils.logging.set_verbosity_warning() |
| 666 | transformers.utils.logging.set_verbosity_info() |
| 667 | else: |
| 668 | datasets.utils.logging.set_verbosity_error() |
| 669 | transformers.utils.logging.set_verbosity_error() |
| 670 | |
| 671 | # Log the training/evaluation parameters (emitted on process 0 only, since other processes log at ERROR level): |
| 672 | logger.info(f"Training/evaluation parameters {training_args}") |
| 673 | |
| 674 | # Load dataset |
| 675 | dataset = Dataset( |
| 676 | **asdict(data_args), |
| 677 | do_train=training_args.do_train, |
| 678 | do_eval=training_args.do_eval, |
| 679 | ) |
| 680 | |
| 681 | logger.info(f"Local TPUs: {jax.local_device_count()}") |
| 682 | logger.info(f"Global TPUs: {jax.device_count()}") |
| 683 | |
| 684 | # Set up wandb run |
| 685 | if jax.process_index() == 0: |
| 686 | wandb.init( |
| 687 | entity=training_args.wandb_entity, |
| 688 | project=training_args.wandb_project, |
| 689 | job_type=training_args.wandb_job_type, |
| 690 | config=parser.parse_args(), |
| 691 | ) |
| 692 | |
| 693 | # Set up our new model config |
no test coverage detected