Main training program.
()
| 713 | |
| 714 | |
| 715 | def main(): |
| 716 | """Main training program.""" |
| 717 | |
| 718 | # Disable CuDNN. |
| 719 | torch.backends.cudnn.enabled = False |
| 720 | # Timer. |
| 721 | timers = Timers() |
| 722 | |
| 723 | # Arguments. |
| 724 | args = get_args() |
| 725 | if args.load: |
| 726 | args.experiment_name = os.path.basename(os.path.normpath(args.load)) |
| 727 | else: |
| 728 | args.experiment_name = args.experiment_name + datetime.now().strftime("%m-%d-%H-%M") |
| 729 | if args.save: |
| 730 | args.save = os.path.join(args.save, args.experiment_name) |
| 731 | # Pytorch distributed. |
| 732 | initialize_distributed(args) |
| 733 | |
| 734 | # Random seeds for reproducability. |
| 735 | set_random_seed(args.seed) |
| 736 | |
| 737 | # init tokenizer |
| 738 | tokenizer = get_tokenizer(args) |
| 739 | |
| 740 | # Data stuff. |
| 741 | train_data, val_data, test_data, args.vocab_size = get_train_val_test_data(args) |
| 742 | |
| 743 | # Model, optimizer, and learning rate. |
| 744 | model, optimizer, lr_scheduler = setup_model_and_optimizer(args) |
| 745 | |
| 746 | if args.load is not None: |
| 747 | if args.fast_load: |
| 748 | args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args) |
| 749 | else: |
| 750 | with FileLock("/root/checkpoint_lock", timeout=-1): |
| 751 | args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args) |
| 752 | else: |
| 753 | args.iteration = 0 |
| 754 | torch.distributed.barrier() |
| 755 | |
| 756 | summary_writer = None |
| 757 | if torch.distributed.get_rank() == 0: |
| 758 | if args.finetune: |
| 759 | print('Finetune CogView model') |
| 760 | else: |
| 761 | print('Pretrain CogView model') |
| 762 | print_args(args) |
| 763 | summary_writer = get_sample_writer(base=args.summary_dir, name=args.experiment_name, iteration=args.iteration) |
| 764 | |
| 765 | # Resume data loader if necessary. |
| 766 | if args.resume_dataloader: |
| 767 | if train_data is not None: |
| 768 | train_data.batch_sampler.start_iter = args.iteration % \ |
| 769 | len(train_data) |
| 770 | if val_data is not None: |
| 771 | start_iter_val = (args.train_iters // args.save_interval) * \ |
| 772 | args.eval_interval |
no test coverage detected