MCPcopy
hub / github.com/huggingface/alignment-handbook / main

Function main

scripts/run_dpo.py:46–257  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

44
45
46def main():
47 parser = H4ArgumentParser((ModelArguments, DataArguments, DPOConfig))
48 model_args, data_args, training_args = parser.parse()
49
50 #######
51 # Setup
52 #######
53 logging.basicConfig(
54 format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
55 datefmt="%Y-%m-%d %H:%M:%S",
56 handlers=[logging.StreamHandler(sys.stdout)],
57 )
58 log_level = training_args.get_process_log_level()
59 logger.setLevel(log_level)
60 transformers.utils.logging.set_verbosity(log_level)
61 transformers.utils.logging.enable_default_handler()
62 transformers.utils.logging.enable_explicit_format()
63
64 # Log on each process the small summary:
65 logger.info(f"Model parameters {model_args}")
66 logger.info(f"Data parameters {data_args}")
67 logger.info(f"Training/evaluation parameters {training_args}")
68
69 # Check for last checkpoint
70 last_checkpoint = get_checkpoint(training_args)
71 if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
72 logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
73
74 # Set seed for reproducibility
75 set_seed(training_args.seed)
76
77 ###############
78 # Load datasets
79 ###############
80 raw_datasets = get_datasets(
81 data_args,
82 splits=data_args.dataset_splits,
83 configs=data_args.dataset_configs,
84 columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
85 )
86 logger.info(
87 f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
88 )
89 column_names = list(raw_datasets["train"].features)
90
91 #####################################
92 # Load tokenizer and process datasets
93 #####################################
94 data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn
95 tokenizer = get_tokenizer(model_args, data_args)
96
97 #####################
98 # Apply chat template
99 #####################
100 raw_datasets = raw_datasets.map(
101 apply_chat_template,
102 fn_kwargs={
103 "tokenizer": tokenizer,

Callers 1

run_dpo.py — File · 0.70

Calls 9

parse — Method · 0.95
H4ArgumentParser — Class · 0.90
get_checkpoint — Function · 0.90
get_datasets — Function · 0.90
get_tokenizer — Function · 0.90
get_quantization_config — Function · 0.90
get_kbit_device_map — Function · 0.90
is_adapter_model — Function · 0.90
get_peft_config — Function · 0.90

Tested by

no test coverage detected