()
| 43 | |
| 44 | |
| 45 | def main(): |
| 46 | parser = H4ArgumentParser((ModelArguments, DataArguments, ORPOConfig)) |
| 47 | model_args, data_args, training_args = parser.parse() |
| 48 | |
| 49 | ####### |
| 50 | # Setup |
| 51 | ####### |
| 52 | logging.basicConfig( |
| 53 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| 54 | datefmt="%Y-%m-%d %H:%M:%S", |
| 55 | handlers=[logging.StreamHandler(sys.stdout)], |
| 56 | ) |
| 57 | log_level = training_args.get_process_log_level() |
| 58 | logger.setLevel(log_level) |
| 59 | transformers.utils.logging.set_verbosity(log_level) |
| 60 | transformers.utils.logging.enable_default_handler() |
| 61 | transformers.utils.logging.enable_explicit_format() |
| 62 | |
| 63 | # Log on each process the small summary: |
| 64 | logger.info(f"Model parameters {model_args}") |
| 65 | logger.info(f"Data parameters {data_args}") |
| 66 | logger.info(f"Training/evaluation parameters {training_args}") |
| 67 | |
| 68 | # Check for last checkpoint |
| 69 | last_checkpoint = get_checkpoint(training_args) |
| 70 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: |
| 71 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") |
| 72 | |
| 73 | # Set seed for reproducibility |
| 74 | set_seed(training_args.seed) |
| 75 | |
| 76 | ############### |
| 77 | # Load datasets |
| 78 | ############### |
| 79 | raw_datasets = get_datasets( |
| 80 | data_args, |
| 81 | splits=data_args.dataset_splits, |
| 82 | configs=data_args.dataset_configs, |
| 83 | columns_to_keep=[ |
| 84 | "prompt", |
| 85 | "chosen", |
| 86 | "rejected", |
| 87 | ], |
| 88 | ) |
| 89 | logger.info( |
| 90 | f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}" |
| 91 | ) |
| 92 | column_names = list(raw_datasets["train"].features) |
| 93 | |
| 94 | ##################################### |
| 95 | # Load tokenizer and process datasets |
| 96 | ##################################### |
| 97 | data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn |
| 98 | tokenizer = get_tokenizer(model_args, data_args) |
| 99 | |
| 100 | torch_dtype = ( |
| 101 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) |
| 102 | ) |
no test coverage detected