hub / github.com/zai-org/CogView / main

Function main

pretrain_gpt2.py:715–815

Main training program.

main()


def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False
    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()
    if args.load:
        args.experiment_name = os.path.basename(os.path.normpath(args.load))
    else:
        args.experiment_name = args.experiment_name + datetime.now().strftime("%m-%d-%H-%M")
    if args.save:
        args.save = os.path.join(args.save, args.experiment_name)
    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # init tokenizer
    tokenizer = get_tokenizer(args)

    # Data stuff.
    train_data, val_data, test_data, args.vocab_size = get_train_val_test_data(args)

    # Model, optimizer, and learning rate.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)

    if args.load is not None:
        if args.fast_load:
            args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args)
        else:
            with FileLock("/root/checkpoint_lock", timeout=-1):
                args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args)
    else:
        args.iteration = 0
    torch.distributed.barrier()

    summary_writer = None
    if torch.distributed.get_rank() == 0:
        if args.finetune:
            print('Finetune CogView model')
        else:
            print('Pretrain CogView model')
        print_args(args)
        summary_writer = get_sample_writer(base=args.summary_dir, name=args.experiment_name, iteration=args.iteration)

    # Resume data loader if necessary.
    if args.resume_dataloader:
        if train_data is not None:
            train_data.batch_sampler.start_iter = args.iteration % \
                len(train_data)
        if val_data is not None:
            start_iter_val = (args.train_iters // args.save_interval) * \
                args.eval_interval
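
The resume arithmetic at the end of the excerpt can be tried in isolation. The following is a minimal sketch, not part of pretrain_gpt2.py: the helper names and the sample numbers are hypothetical, and only the two expressions are taken from the code above. What the function does with start_iter_val afterwards lies outside the excerpt, so the sketch stops at computing it.

```python
def train_resume_offset(iteration: int, batches_per_epoch: int) -> int:
    """Mirror of `args.iteration % len(train_data)`: the position inside
    the current pass over the training loader at which to resume."""
    if batches_per_epoch <= 0:
        raise ValueError("batches_per_epoch must be positive")
    return iteration % batches_per_epoch


def val_start_value(train_iters: int, save_interval: int, eval_interval: int) -> int:
    """Mirror of `(args.train_iters // args.save_interval) * args.eval_interval`.
    The excerpt does not show how this value is used afterwards."""
    if save_interval <= 0:
        raise ValueError("save_interval must be positive")
    return (train_iters // save_interval) * eval_interval


# Hypothetical numbers, for illustration only.
print(train_resume_offset(iteration=12_500, batches_per_epoch=4_000))  # 500
print(val_start_value(train_iters=100_000, save_interval=5_000, eval_interval=1_000))  # 20000
```

With 4,000 batches per pass, a checkpoint saved at iteration 12,500 resumes 500 batches into the current pass, so the batches before that position are skipped.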

Callers 1

pretrain_gpt2.py · File · 0.70

Calls 13

Timers · Class · 0.90
get_args · Function · 0.90
get_tokenizer · Function · 0.90
load_checkpoint · Function · 0.90
print_args · Function · 0.90
get_sample_writer · Function · 0.90
save_checkpoint · Function · 0.90
initialize_distributed · Function · 0.85
set_random_seed · Function · 0.85
get_train_val_test_data · Function · 0.85
train · Function · 0.85

Tested by

no test coverage detected