hub / github.com/deepspeedai/DeepSpeedExamples / main

Function main

Megatron-LM/pretrain_bert.py:505–577

Main training program.

()


def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)
    if torch.distributed.get_rank() == 0:
        print('Pretrain BERT model')
        print_args(args)

    # Random seeds for reproducability.
    set_random_seed(args.seed)

    # Data stuff.
    train_data, val_data, test_data, args.tokenizer_num_tokens, \
        args.tokenizer_num_type_tokens = get_train_val_test_data(args)

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)

    if args.resume_dataloader:
        if train_data is not None:
            train_data.batch_sampler.start_iter = args.iteration % \
                                                  len(train_data)
        if val_data is not None:
            start_iter_val = (args.train_iters // args.save_interval) * \
                             args.eval_interval
            val_data.batch_sampler.start_iter = start_iter_val % \
                                                len(val_data)

    if train_data is not None:
        train_data_iterator = iter(train_data)
    else:
        train_data_iterator = None
    if val_data is not None:
        val_data_iterator = iter(val_data)
    else:
        val_data_iterator = None

    iteration = 0
    if args.train_iters > 0:
        if args.do_train:
            iteration, skipped = train(model, optimizer,
                                       lr_scheduler,
                                       train_data_iterator,
                                       val_data_iterator,
                                       timers, args)
        if args.do_valid:
            prefix = 'the end of training for val data'
            val_loss = evaluate_and_print_results(prefix, val_data_iterator,
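
The resume_dataloader branch in the listing sets each batch sampler's start_iter with modular arithmetic. The sketch below isolates that arithmetic so it can be checked by hand. The helper name resume_offsets and the numeric values are hypothetical and do not appear in pretrain_bert.py; the integer parameters stand in for args.iteration, args.train_iters, args.save_interval, args.eval_interval, len(train_data) and len(val_data).

```python
def resume_offsets(iteration, train_iters, save_interval, eval_interval,
                   train_len, val_len):
    """Mirror the start_iter arithmetic of the resume_dataloader branch.

    Hypothetical helper: every parameter is a plain int standing in for
    the corresponding attribute or loader length used in main().
    """
    # Training sampler: position inside the current pass over the loader.
    train_start = iteration % train_len
    # Validation sampler: same expression as the listing, wrapped to the
    # validation loader length.
    val_start = ((train_iters // save_interval) * eval_interval) % val_len
    return train_start, val_start


# Illustrative values only, chosen to keep the arithmetic easy to follow.
train_start, val_start = resume_offsets(
    iteration=2500, train_iters=10000, save_interval=1000,
    eval_interval=100, train_len=1000, val_len=300)
print(train_start, val_start)  # prints "500 100"
```

Both offsets wrap around the loader length, so a resumed run starts partway through a pass instead of past the end of the loader. As written in the listing, the validation offset depends on args.train_iters and args.save_interval, not on args.iteration.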

Callers 1

pretrain_bert.py · File · 0.70

Calls 10

Timers · Class · 0.90
get_args · Function · 0.90
print_args · Function · 0.90
save_checkpoint · Function · 0.90
initialize_distributed · Function · 0.70
set_random_seed · Function · 0.70
get_train_val_test_data · Function · 0.70
train · Function · 0.70

Tested by

no test coverage detected