()
| 47 | logger = logging.getLogger(__name__) |
| 48 | |
| 49 | def main(): |
| 50 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) |
| 51 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): |
| 52 | # If we pass only one argument to the script and it's the path to a json file, |
| 53 | # let's parse it to get our arguments. |
| 54 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) |
| 55 | else: |
| 56 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
| 57 | |
| 58 | # Setup logging |
| 59 | logging.basicConfig( |
| 60 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| 61 | datefmt="%m/%d/%Y %H:%M:%S", |
| 62 | handlers=[logging.StreamHandler(sys.stdout)], |
| 63 | ) |
| 64 | |
| 65 | if training_args.should_log: |
| 66 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. |
| 67 | transformers.utils.logging.set_verbosity_info() |
| 68 | |
| 69 | log_level = training_args.get_process_log_level() |
| 70 | logger.setLevel(log_level) |
| 71 | # datasets.utils.logging.set_verbosity(log_level) |
| 72 | transformers.utils.logging.set_verbosity(log_level) |
| 73 | transformers.utils.logging.enable_default_handler() |
| 74 | transformers.utils.logging.enable_explicit_format() |
| 75 | |
| 76 | # Log on each process the small summary: |
| 77 | logger.warning( |
| 78 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" |
| 79 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" |
| 80 | ) |
| 81 | logger.info(f"Training/evaluation parameters {training_args}") |
| 82 | |
| 83 | # Set seed before initializing model. |
| 84 | set_seed(training_args.seed) |
| 85 | |
| 86 | # Load dataset |
| 87 | data_files = {} |
| 88 | if data_args.train_file is not None: |
| 89 | data_files["train"] = data_args.train_file |
| 90 | extension = data_args.train_file.split(".")[-1] |
| 91 | if data_args.validation_file is not None: |
| 92 | data_files["validation"] = data_args.validation_file |
| 93 | extension = data_args.validation_file.split(".")[-1] |
| 94 | if data_args.test_file is not None: |
| 95 | data_files["test"] = data_args.test_file |
| 96 | extension = data_args.test_file.split(".")[-1] |
| 97 | |
| 98 | raw_datasets = load_dataset( |
| 99 | extension, |
| 100 | data_files=data_files, |
| 101 | cache_dir=model_args.cache_dir, |
| 102 | use_auth_token=True if model_args.use_auth_token else None, |
| 103 | ) |
| 104 | |
| 105 | # Load pretrained model and tokenizer |
| 106 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) |
no test coverage detected