github.com/huggingface/alignment-handbook / main

Function main()

scripts/run_cpt.py, lines 47–205

```python
def main():
    parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
    model_args, data_args, training_args = parser.parse()

    # Set seed for reproducibility
    set_seed(training_args.seed)

    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # Per-process level: the main process and replica processes can log at different verbosities
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log a short summary on each process
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    # Check for last checkpoint
    last_checkpoint = get_checkpoint(training_args)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    ###############
    # Load datasets
    ###############
    raw_datasets = get_datasets(
        data_args,
        splits=data_args.dataset_splits,
        configs=data_args.dataset_configs,
        columns_to_keep=[data_args.text_column],
    )

    logger.info(
        f"Training on the following datasets and their proportions:"
        f" {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )

    # A "train" split is required; a "test" split is optional
    train_dataset = raw_datasets["train"] if "train" in raw_datasets else None
    eval_dataset = raw_datasets["test"] if "test" in raw_datasets else None

    if train_dataset is None:
        raise ValueError(
            "Training set must be included (so make sure that your dataset has a split with 'train' in the name)."
        )

    # ... (lines 105–205 of scripts/run_cpt.py are not included in this excerpt)
```
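main() opens by parsing three argument groups (ModelArguments, DataArguments, SFTConfig) with H4ArgumentParser and unpacking one object per group. The sketch below illustrates that dataclass-per-group idea with only the standard library, loosely modeled on transformers' HfArgumentParser. ModelArgs, DataArgs, TrainArgs, their fields and parse_into_dataclasses are hypothetical, and the sketch handles command-line flags only, with no attempt at config-file parsing.

```python
import argparse
from dataclasses import dataclass, fields
from typing import List, Sequence, Tuple, Type


# Hypothetical stand-ins for ModelArguments, DataArguments and SFTConfig.
@dataclass
class ModelArgs:
    model_name_or_path: str = "gpt2"


@dataclass
class DataArgs:
    text_column: str = "text"


@dataclass
class TrainArgs:
    seed: int = 42
    output_dir: str = "data/cpt-run"


def parse_into_dataclasses(types: Sequence[Type], argv: List[str]) -> Tuple:
    """Hypothetical helper: build one CLI from several dataclasses, return one instance per class."""
    parser = argparse.ArgumentParser()
    for dc in types:
        for f in fields(dc):
            # Reuse the default's type so "--seed 7" arrives as an int, not a str.
            parser.add_argument(f"--{f.name}", type=type(f.default), default=f.default)
    namespace = parser.parse_args(argv)
    # Route each parsed value back to the dataclass that declared it.
    return tuple(dc(**{f.name: getattr(namespace, f.name) for f in fields(dc)}) for dc in types)


if __name__ == "__main__":
    model_args, data_args, training_args = parse_into_dataclasses(
        (ModelArgs, DataArgs, TrainArgs), ["--seed", "7", "--text_column", "content"]
    )
    print(training_args.seed, data_args.text_column)  # 7 content
```

Because every flag lands in a single argparse namespace, field names must be unique across the groups; argparse raises an error on a conflicting option string.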
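The logging block in main() sends records to stdout in a timestamped format and picks a level per process with training_args.get_process_log_level(), which in transformers returns the log_level setting on the main process and log_level_replica on the others. A stdlib-only sketch of that split follows, assuming (hypothetically) INFO for the main process and WARNING for replicas; process_log_level and setup_logging are illustrative helpers, not handbook or transformers APIs.

```python
import logging
import sys

# Same format strings as the logging.basicConfig call in main().
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"


def process_log_level(is_main_process: bool,
                      main_level: int = logging.INFO,
                      replica_level: int = logging.WARNING) -> int:
    """Hypothetical stand-in for get_process_log_level(): verbose main process, quiet replicas."""
    return main_level if is_main_process else replica_level


def setup_logging(rank: int, name: str = "run_cpt_sketch") -> logging.Logger:
    # Treat rank -1 (non-distributed run) and rank 0 as the main process.
    is_main = rank in (-1, 0)
    logging.basicConfig(
        format=LOG_FORMAT,
        datefmt=DATE_FORMAT,
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger = logging.getLogger(name)
    logger.setLevel(process_log_level(is_main))
    return logger


if __name__ == "__main__":
    # Simulate two ranks in one process purely for illustration.
    for rank in (0, 1):
        log = setup_logging(rank, name=f"rank{rank}")
        log.info("printed only by the main process")
        log.warning("printed by every process")
```

Under this split, logger.info calls such as the parameter dumps in main() appear once rather than once per GPU, while logger.warning calls such as the per-process summary still print on every rank.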
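The checkpoint check in main() calls get_checkpoint(training_args), whose body lies outside this excerpt. A sketch of one common way to implement such a lookup, assuming Hugging Face Trainer-style checkpoint-<step> folders under output_dir (the naming that transformers.trainer_utils.get_last_checkpoint scans for); find_last_checkpoint is hypothetical and not the handbook's code.

```python
import os
import re
import tempfile
from typing import Optional

# Trainer-style checkpoint folders are named "checkpoint-<global_step>".
CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir: str) -> Optional[str]:
    """Hypothetical helper: return the checkpoint dir with the highest step, or None."""
    if not os.path.isdir(output_dir):
        return None  # fresh run: nothing to resume from
    candidates = [
        name
        for name in os.listdir(output_dir)
        if CHECKPOINT_RE.match(name) and os.path.isdir(os.path.join(output_dir, name))
    ]
    if not candidates:
        return None
    # Compare step numbers numerically so checkpoint-1000 beats checkpoint-900.
    latest = max(candidates, key=lambda name: int(CHECKPOINT_RE.match(name).group(1)))
    return os.path.join(output_dir, latest)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as out:
        for step in (500, 900, 1000):
            os.makedirs(os.path.join(out, f"checkpoint-{step}"))
        last = find_last_checkpoint(out)
        resume_from_checkpoint = None  # stands in for training_args.resume_from_checkpoint
        if last is not None and resume_from_checkpoint is None:
            print(f"Checkpoint detected, resuming training at {last=}.")
```

The excerpt itself only logs the detection; how last_checkpoint is passed on to training happens in the lines not shown.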
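After loading, main() logs the row count of every split, then requires a "train" split and treats "test" as optional. The same logic on plain Python data follows, assuming each split is a list of row dicts standing in for a datasets.Dataset and that columns_to_keep discards every other field; keep_columns, split_summary and select_splits are hypothetical helpers.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical stand-in for a DatasetDict: split name -> list of rows.
Rows = List[dict]


def keep_columns(raw: Dict[str, Rows], columns: List[str]) -> Dict[str, Rows]:
    """Drop every field except `columns` (assumed behavior of columns_to_keep)."""
    return {split: [{c: row[c] for c in columns} for row in rows] for split, rows in raw.items()}


def split_summary(raw: Dict[str, Rows]) -> List[str]:
    # Same "split : num_rows" strings that main() logs.
    return [split + " : " + str(len(rows)) for split, rows in raw.items()]


def select_splits(raw: Dict[str, Rows]) -> Tuple[Rows, Optional[Rows]]:
    train = raw.get("train")
    evaluation = raw.get("test")  # a missing "test" split simply means no evaluation set
    if train is None:
        raise ValueError(
            "Training set must be included (so make sure that your dataset has a split with 'train' in the name)."
        )
    return train, evaluation


if __name__ == "__main__":
    raw = {
        "train": [{"text": "a", "id": 1}, {"text": "b", "id": 2}],
        "test": [{"text": "c", "id": 3}],
    }
    raw = keep_columns(raw, ["text"])
    print(split_summary(raw))  # ['train : 2', 'test : 1']
    train, evaluation = select_splits(raw)
```

Note that, despite the "proportions" wording in the log message, the values logged by main() are absolute row counts per split.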

Callers (1)

run_cpt.py · File · 0.70

Calls (8)

parse · Method · 0.95
H4ArgumentParser · Class · 0.90
get_checkpoint · Function · 0.90
get_datasets · Function · 0.90
get_tokenizer · Function · 0.90
get_quantization_config · Function · 0.90
get_kbit_device_map · Function · 0.90
get_peft_config · Function · 0.90

Tested by

No test coverage detected.