()
| 42 | |
| 43 | |
| 44 | def main(): |
| 45 | parser = StarChatArgumentParser((ModelArguments, DataArguments, TrainingArguments)) |
| 46 | if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): |
| 47 | # If we pass only one argument to the script and it's the path to a YAML file, |
| 48 | # let's parse it to get our arguments. |
| 49 | model_args, data_args, training_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1])) |
| 50 | # parse command line args and yaml file |
| 51 | elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"): |
| 52 | model_args, data_args, training_args = parser.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:]) |
| 53 | # parse command line args only |
| 54 | else: |
| 55 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
| 56 | |
| 57 | # Set seed for reproducibility |
| 58 | set_seed(training_args.seed) |
| 59 | |
| 60 | ############### |
| 61 | # Setup logging |
| 62 | ############### |
| 63 | logging.basicConfig( |
| 64 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| 65 | datefmt="%Y-%m-%d %H:%M:%S", |
| 66 | handlers=[logging.StreamHandler(sys.stdout)], |
| 67 | ) |
| 68 | log_level = training_args.get_process_log_level() |
| 69 | logger.setLevel(log_level) |
| 70 | datasets.utils.logging.set_verbosity(log_level) |
| 71 | transformers.utils.logging.set_verbosity(log_level) |
| 72 | transformers.utils.logging.enable_default_handler() |
| 73 | transformers.utils.logging.enable_explicit_format() |
| 74 | |
| 75 | # Log on each process a small summary |
| 76 | logger.warning( |
| 77 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" |
| 78 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" |
| 79 | ) |
| 80 | logger.info(f"Model parameters {model_args}") |
| 81 | logger.info(f"Data parameters {data_args}") |
| 82 | logger.info(f"Training/evaluation parameters {training_args}") |
| 83 | |
| 84 | # Login to HuggingFace Hub if needed |
| 85 | hf_login() |
| 86 | |
| 87 | ########################### |
| 88 | # Detecting last checkpoint |
| 89 | ########################### |
| 90 | last_checkpoint = None |
| 91 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: |
| 92 | last_checkpoint = get_last_checkpoint(training_args.output_dir) |
| 93 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: |
| 94 | raise ValueError( |
| 95 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " |
| 96 | "Use --overwrite_output_dir to overcome." |
| 97 | ) |
| 98 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: |
| 99 | logger.info( |
| 100 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " |
| 101 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." |
no test coverage detected