`main()`: Main training program.
| 503 | |
| 504 | |
| 505 | def main(): |
| 506 | """Main training program.""" |
| 507 | |
| 508 | # Disable CuDNN. |
| 509 | torch.backends.cudnn.enabled = False |
| 510 | |
| 511 | # Timer. |
| 512 | timers = Timers() |
| 513 | |
| 514 | # Arguments. |
| 515 | args = get_args() |
| 516 | |
| 517 | # Pytorch distributed. |
| 518 | initialize_distributed(args) |
| 519 | if torch.distributed.get_rank() == 0: |
| 520 | print('Pretrain BERT model') |
| 521 | print_args(args) |
| 522 | |
| 523 | # Random seeds for reproducability. |
| 524 | set_random_seed(args.seed) |
| 525 | |
| 526 | # Data stuff. |
| 527 | train_data, val_data, test_data, args.tokenizer_num_tokens, \ |
| 528 | args.tokenizer_num_type_tokens = get_train_val_test_data(args) |
| 529 | |
| 530 | # Model, optimizer, and learning rate. |
| 531 | model, optimizer, lr_scheduler = setup_model_and_optimizer(args) |
| 532 | |
| 533 | if args.resume_dataloader: |
| 534 | if train_data is not None: |
| 535 | train_data.batch_sampler.start_iter = args.iteration % \ |
| 536 | len(train_data) |
| 537 | if val_data is not None: |
| 538 | start_iter_val = (args.train_iters // args.save_interval) * \ |
| 539 | args.eval_interval |
| 540 | val_data.batch_sampler.start_iter = start_iter_val % \ |
| 541 | len(val_data) |
| 542 | |
| 543 | if train_data is not None: |
| 544 | train_data_iterator = iter(train_data) |
| 545 | else: |
| 546 | train_data_iterator = None |
| 547 | if val_data is not None: |
| 548 | val_data_iterator = iter(val_data) |
| 549 | else: |
| 550 | val_data_iterator = None |
| 551 | |
| 552 | iteration = 0 |
| 553 | if args.train_iters > 0: |
| 554 | if args.do_train: |
| 555 | iteration, skipped = train(model, optimizer, |
| 556 | lr_scheduler, |
| 557 | train_data_iterator, |
| 558 | val_data_iterator, |
| 559 | timers, args) |
| 560 | if args.do_valid: |
| 561 | prefix = 'the end of training for val data' |
| 562 | val_loss = evaluate_and_print_results(prefix, val_data_iterator, |
no test coverage detected