()
| 45 | |
| 46 | |
| 47 | def main(): |
| 48 | parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig)) |
| 49 | model_args, data_args, training_args = parser.parse() |
| 50 | |
| 51 | # Set seed for reproducibility |
| 52 | set_seed(training_args.seed) |
| 53 | |
| 54 | ############### |
| 55 | # Setup logging |
| 56 | ############### |
| 57 | logging.basicConfig( |
| 58 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| 59 | datefmt="%Y-%m-%d %H:%M:%S", |
| 60 | handlers=[logging.StreamHandler(sys.stdout)], |
| 61 | ) |
| 62 | log_level = training_args.get_process_log_level() |
| 63 | logger.setLevel(log_level) |
| 64 | datasets.utils.logging.set_verbosity(log_level) |
| 65 | transformers.utils.logging.set_verbosity(log_level) |
| 66 | transformers.utils.logging.enable_default_handler() |
| 67 | transformers.utils.logging.enable_explicit_format() |
| 68 | |
| 69 | # Log on each process a small summary |
| 70 | logger.warning( |
| 71 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" |
| 72 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" |
| 73 | ) |
| 74 | logger.info(f"Model parameters {model_args}") |
| 75 | logger.info(f"Data parameters {data_args}") |
| 76 | logger.info(f"Training/evaluation parameters {training_args}") |
| 77 | |
| 78 | # Check for last checkpoint |
| 79 | last_checkpoint = get_checkpoint(training_args) |
| 80 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: |
| 81 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") |
| 82 | |
| 83 | ############### |
| 84 | # Load datasets |
| 85 | ############### |
| 86 | raw_datasets = get_datasets( |
| 87 | data_args, |
| 88 | splits=data_args.dataset_splits, |
| 89 | configs=data_args.dataset_configs, |
| 90 | columns_to_keep=[data_args.text_column], |
| 91 | ) |
| 92 | |
| 93 | logger.info( |
| 94 | f"Training on the following datasets and their proportions:" |
| 95 | f" {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}" |
| 96 | ) |
| 97 | |
| 98 | train_dataset = raw_datasets["train"] if "train" in raw_datasets else None |
| 99 | eval_dataset = raw_datasets["test"] if "test" in raw_datasets else None |
| 100 | |
| 101 | if train_dataset is None: |
| 102 | raise ValueError( |
| 103 | "Training set must be included (so make sure that your dataset has a split with" " 'train' in the name)." |
| 104 | ) |
no test coverage detected