MCPcopy
hub / github.com/zai-org/ChatGLM2-6B / main

Function main

ptuning/main.py:49–402  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

47logger = logging.getLogger(__name__)
48
49def main():
50 parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
51 if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
52 # If we pass only one argument to the script and it's the path to a json file,
53 # let's parse it to get our arguments.
54 model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
55 else:
56 model_args, data_args, training_args = parser.parse_args_into_dataclasses()
57
58 # Setup logging
59 logging.basicConfig(
60 format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
61 datefmt="%m/%d/%Y %H:%M:%S",
62 handlers=[logging.StreamHandler(sys.stdout)],
63 )
64
65 if training_args.should_log:
66 # The default of training_args.log_level is passive, so we set log level at info here to have that default.
67 transformers.utils.logging.set_verbosity_info()
68
69 log_level = training_args.get_process_log_level()
70 logger.setLevel(log_level)
71 # datasets.utils.logging.set_verbosity(log_level)
72 transformers.utils.logging.set_verbosity(log_level)
73 transformers.utils.logging.enable_default_handler()
74 transformers.utils.logging.enable_explicit_format()
75
76 # Log on each process the small summary:
77 logger.warning(
78 f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
79 + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
80 )
81 logger.info(f"Training/evaluation parameters {training_args}")
82
83 # Set seed before initializing model.
84 set_seed(training_args.seed)
85
86 # Load dataset
87 data_files = {}
88 if data_args.train_file is not None:
89 data_files["train"] = data_args.train_file
90 extension = data_args.train_file.split(".")[-1]
91 if data_args.validation_file is not None:
92 data_files["validation"] = data_args.validation_file
93 extension = data_args.validation_file.split(".")[-1]
94 if data_args.test_file is not None:
95 data_files["test"] = data_args.test_file
96 extension = data_args.test_file.split(".")[-1]
97
98 raw_datasets = load_dataset(
99 extension,
100 data_files=data_files,
101 cache_dir=model_args.cache_dir,
102 use_auth_token=True if model_args.use_auth_token else None,
103 )
104
105 # Load pretrained model and tokenizer
106 config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

Callers 2

_mp_fn Function · 0.70
main.py File · 0.70

Calls 4

evaluate Method · 0.95
predict Method · 0.95
Seq2SeqTrainer Class · 0.90
print_dataset_example Function · 0.85

Tested by

no test coverage detected