MCPcopy
hub / github.com/huggingface/alignment-handbook / main

Function main

scripts/run_sft.py:49–229  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

47
48
49def main():
50 parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
51 model_args, data_args, training_args = parser.parse()
52
53 # Set seed for reproducibility
54 set_seed(training_args.seed)
55
56 ###############
57 # Setup logging
58 ###############
59 logging.basicConfig(
60 format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
61 datefmt="%Y-%m-%d %H:%M:%S",
62 handlers=[logging.StreamHandler(sys.stdout)],
63 )
64 log_level = training_args.get_process_log_level()
65 logger.setLevel(log_level)
66 datasets.utils.logging.set_verbosity(log_level)
67 transformers.utils.logging.set_verbosity(log_level)
68 transformers.utils.logging.enable_default_handler()
69 transformers.utils.logging.enable_explicit_format()
70
71 # Log on each process a small summary
72 logger.warning(
73 f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
74 + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
75 )
76 logger.info(f"Model parameters {model_args}")
77 logger.info(f"Data parameters {data_args}")
78 logger.info(f"Training/evaluation parameters {training_args}")
79
80 # Check for last checkpoint
81 last_checkpoint = get_checkpoint(training_args)
82 if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
83 logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
84
85 ###############
86 # Load datasets
87 ###############
88 raw_datasets = get_datasets(
89 data_args,
90 splits=data_args.dataset_splits,
91 configs=data_args.dataset_configs,
92 columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
93 )
94 logger.info(
95 f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
96 )
97 column_names = list(raw_datasets["train"].features)
98
99 ################
100 # Load tokenizer
101 ################
102 tokenizer = get_tokenizer(model_args, data_args)
103
104 #######################
105 # Load pretrained model
106 #######################

Callers 1

run_sft.py · File · 0.70

Calls 8

parse · Method · 0.95
H4ArgumentParser · Class · 0.90
get_checkpoint · Function · 0.90
get_datasets · Function · 0.90
get_tokenizer · Function · 0.90
get_quantization_config · Function · 0.90
get_kbit_device_map · Function · 0.90
get_peft_config · Function · 0.90

Tested by

no test coverage detected