github.com/huggingface/alignment-handbook / main

Function main()

scripts/run_orpo.py, lines 45–266

def main():
    parser = H4ArgumentParser((ModelArguments, DataArguments, ORPOConfig))
    model_args, data_args, training_args = parser.parse()

    #######
    # Setup
    #######
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    # Check for last checkpoint
    last_checkpoint = get_checkpoint(training_args)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    # Set seed for reproducibility
    set_seed(training_args.seed)

    ###############
    # Load datasets
    ###############
    raw_datasets = get_datasets(
        data_args,
        splits=data_args.dataset_splits,
        configs=data_args.dataset_configs,
        columns_to_keep=[
            "prompt",
            "chosen",
            "rejected",
        ],
    )
    logger.info(
        f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )
    column_names = list(raw_datasets["train"].features)

    #####################################
    # Load tokenizer and process datasets
    #####################################
    data_args.truncation_side = "left"  # Truncate from left to ensure we don't lose labels in final turn
    tokenizer = get_tokenizer(model_args, data_args)

    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )

Callers (1)

run_orpo.py · File · 0.70

Calls (8)

parse · Method · 0.95
H4ArgumentParser · Class · 0.90
get_checkpoint · Function · 0.90
get_datasets · Function · 0.90
get_tokenizer · Function · 0.90
get_quantization_config · Function · 0.90
get_kbit_device_map · Function · 0.90
get_peft_config · Function · 0.90

Tested by

No test coverage detected.