MCPcopy
hub / github.com/lm-sys/FastChat / train

Function train

fastchat/train/train.py:256–314  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

254
255
256def train():
257 global local_rank
258
259 parser = transformers.HfArgumentParser(
260 (ModelArguments, DataArguments, TrainingArguments)
261 )
262 model_args, data_args, training_args = parser.parse_args_into_dataclasses()
263 local_rank = training_args.local_rank
264
265 # Set RoPE scaling factor
266 config = transformers.AutoConfig.from_pretrained(
267 model_args.model_name_or_path,
268 cache_dir=training_args.cache_dir,
269 trust_remote_code=model_args.trust_remote_code,
270 )
271 orig_ctx_len = getattr(config, "max_position_embeddings", None)
272 if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
273 scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
274 config.rope_scaling = {"type": "linear", "factor": scaling_factor}
275 config.use_cache = False
276
277 # Load model and tokenizer
278 model = transformers.AutoModelForCausalLM.from_pretrained(
279 model_args.model_name_or_path,
280 config=config,
281 cache_dir=training_args.cache_dir,
282 trust_remote_code=model_args.trust_remote_code,
283 )
284 tokenizer = transformers.AutoTokenizer.from_pretrained(
285 model_args.model_name_or_path,
286 cache_dir=training_args.cache_dir,
287 model_max_length=training_args.model_max_length,
288 padding_side=model_args.padding_side,
289 use_fast=False,
290 trust_remote_code=model_args.trust_remote_code,
291 )
292
293 if tokenizer.pad_token != tokenizer.unk_token:
294 tokenizer.pad_token = tokenizer.unk_token
295
296 # Load data
297 data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
298
299 # Start trainner
300 trainer = Trainer(
301 model=model, tokenizer=tokenizer, args=training_args, **data_module
302 )
303 if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
304 trainer.train(resume_from_checkpoint=True)
305 else:
306 trainer.train()
307
308 # Save model
309 model.config.use_cache = True
310 trainer.save_state()
311 if trainer.is_deepspeed_enabled:
312 trainer.save_model()
313 else:

Callers 3

train_mem.py · File · 0.90
train_xformers.py · File · 0.90
train.py · File · 0.70

Calls 2

trainer_save_model_safe · Function · 0.70

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…