MCPcopy
hub / github.com/borisdayma/dalle-mini / main

Function main

tools/train/train.py:636–1660  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

634
635
636def main():
637 # See all possible arguments by passing the --help flag to this script.
638 parser = HfArgumentParser(
639 (ModelArguments, DataTrainingArguments, TrainingArguments)
640 )
641 if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
642 # If we pass only one argument to the script and it's the path to a json file,
643 # let's parse it to get our arguments.
644 model_args, data_args, training_args = parser.parse_json_file(
645 json_file=os.path.abspath(sys.argv[1])
646 )
647 else:
648 model_args, data_args, training_args = parser.parse_args_into_dataclasses()
649
650 # check arguments
651 if training_args.mp_devices > jax.local_device_count():
652 assert (
653 data_args.seed_dataset is not None
654 ), "Seed dataset must be provided when model is split over multiple hosts"
655
656 # Make one log on every process with the configuration for debugging.
657 logging.basicConfig(
658 format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
659 datefmt="%m/%d/%Y %H:%M:%S",
660 level=logging.INFO,
661 )
662 # Setup logging, we only want one process per machine to log things on the screen.
663 logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
664 if jax.process_index() == 0:
665 datasets.utils.logging.set_verbosity_warning()
666 transformers.utils.logging.set_verbosity_info()
667 else:
668 datasets.utils.logging.set_verbosity_error()
669 transformers.utils.logging.set_verbosity_error()
670
671 # Set the verbosity to info of the Transformers logger (on main process only):
672 logger.info(f"Training/evaluation parameters {training_args}")
673
674 # Load dataset
675 dataset = Dataset(
676 **asdict(data_args),
677 do_train=training_args.do_train,
678 do_eval=training_args.do_eval,
679 )
680
681 logger.info(f"Local TPUs: {jax.local_device_count()}")
682 logger.info(f"Global TPUs: {jax.device_count()}")
683
684 # Set up wandb run
685 if jax.process_index() == 0:
686 wandb.init(
687 entity=training_args.wandb_entity,
688 project=training_args.wandb_project,
689 job_type=training_args.wandb_job_type,
690 config=parser.parse_args(),
691 )
692
693 # Set up our new model config

Callers 1

train.py · File · 0.85

Calls 15

preprocess · Method · 0.95
num_params · Method · 0.95
update_state_metrics · Method · 0.95
log · Method · 0.95
dataloader · Method · 0.95
Dataset · Class · 0.90
DalleBart · Class · 0.90
set_partitions · Function · 0.90
distributed_shampoo · Function · 0.90
create_learning_rate_fn · Function · 0.85
split_params · Function · 0.85

Tested by

no test coverage detected