hub / github.com/zai-org/CogView / train

Function train

pretrain_gpt2.py:482–566

Train the model.

train(model, optimizer, lr_scheduler, train_data_iterator, val_data_iterator, timers, args, summary_writer=None)

Source from the content-addressed store, hash-verified

def train(model, optimizer, lr_scheduler,
          train_data_iterator, val_data_iterator, timers, args, summary_writer=None):
    """Train the model."""
    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_lm_loss = 0.0
    total_img_loss = total_txt_loss = 0.0

    # Iterations.
    skipped_iters = 0

    timers('interval time').start()
    report_memory_flag = True
    mems = []
    while args.iteration < args.train_iters:

        # Every 100 iterations, check whether new datasets have appeared.
        if args.iteration % 100 == 0:
            new_loaders = detect_new_datasets(args)
            if new_loaders is not None:
                print(f'Loading new datasets ... Now we train models on {args.train_data}.')
                train_data_iterator = iter(new_loaders[0])
                val_data_iterator = iter(new_loaders[1])
                # TODO close the original

        lm_loss, skipped_iter, mems, img_loss, txt_loss = train_step(train_data_iterator,
                                                                     model,
                                                                     optimizer,
                                                                     lr_scheduler,
                                                                     args, timers, mems)
        skipped_iters += skipped_iter
        args.iteration += 1

        # Update losses.
        total_lm_loss += lm_loss.data.detach().float()
        total_img_loss += img_loss.data.detach().float()
        total_txt_loss += txt_loss.data.detach().float()

        # Logging.
        if args.iteration % args.log_interval == 0:
            learning_rate = optimizer.param_groups[0]['lr']
            avg_lm_loss = total_lm_loss.item() / args.log_interval
            # average img & txt loss
            avg_img_loss = total_img_loss.item() / args.log_interval
            avg_txt_loss = total_txt_loss.item() / args.log_interval

            elapsed_time = timers('interval time').elapsed()
            report_iteration_metrics(summary_writer, optimizer, learning_rate, avg_lm_loss,
                                     elapsed_time * 1000.0 / args.log_interval, args.iteration, args.train_iters, args,
                                     avg_img_loss, avg_txt_loss)
            # Reset the accumulators for the next logging interval.
            total_lm_loss = 0.0
            total_img_loss = 0.0
            total_txt_loss = 0.0
            if report_memory_flag:
                report_memory('after {} iterations'.format(args.iteration))
                report_memory_flag = False
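The listing combines two patterns: a periodic check for a replacement data source, made before each step, and a running loss that is averaged over each logging interval and then reset. The following is a minimal sketch of that control flow only, not the repository's implementation. The names `run_loop`, `detect_new_data`, and `refresh_every` are hypothetical, plain floats stand in for loss tensors, and `next(data_iter)` stands in for `train_step`.

```python
from typing import Callable, Iterator, List, Optional, Tuple


def run_loop(
    data_iter: Iterator[float],
    train_iters: int,
    log_interval: int,
    refresh_every: int = 100,
    detect_new_data: Optional[Callable[[int], Optional[Iterator[float]]]] = None,
) -> List[Tuple[int, float]]:
    """Toy loop (hypothetical helper): returns (iteration, mean loss) pairs."""
    reports: List[Tuple[int, float]] = []
    running_loss = 0.0
    iteration = 0
    while iteration < train_iters:
        # Before the step, periodically ask whether a new data source exists.
        if detect_new_data is not None and iteration % refresh_every == 0:
            new_iter = detect_new_data(iteration)
            if new_iter is not None:
                data_iter = new_iter

        # Stand-in for train_step: the "loss" is simply the next value.
        loss = next(data_iter)
        iteration += 1
        running_loss += loss

        # After the step, report the mean over the interval and reset.
        if iteration % log_interval == 0:
            reports.append((iteration, running_loss / log_interval))
            running_loss = 0.0
    return reports


def swap_at_four(iteration: int) -> Optional[Iterator[float]]:
    """Hypothetical detector: offers a new data source only at iteration 4."""
    return iter([10.0] * 100) if iteration == 4 else None


reports = run_loop(
    iter([1.0, 2.0, 3.0, 4.0]),
    train_iters=8,
    log_interval=2,
    refresh_every=4,
    detect_new_data=swap_at_four,
)
print(reports)  # [(2, 1.5), (4, 3.5), (6, 10.0), (8, 10.0)]
```

Because the refresh check uses the pre-increment counter and the logging check uses the post-increment counter, the swap takes effect on the step that starts a new logging interval, so no reported average mixes the two data sources in this example.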

Callers 1

main · Function · 0.85

Calls 9

detect_new_datasets · Function · 0.90
report_memory · Function · 0.90
save_checkpoint · Function · 0.90
train_step · Function · 0.85
report_iteration_metrics · Function · 0.85
start · Method · 0.80
elapsed · Method · 0.80
log · Method · 0.80

Tested by

no test coverage detected