Load the data on rank zero and broadcast the number of tokens to all GPUs.
(args)
| 678 | |
| 679 | |
def get_train_val_test_data(args):
    """Build the data loaders on model-parallel rank 0 and share the vocab size.

    Only the process whose model-parallel rank is 0 calls ``make_loaders``;
    every other process returns ``None`` for the three loaders. The padded
    vocabulary size and the ``do_train`` / ``do_valid`` / ``do_test`` flags
    are packed into one CUDA long tensor and broadcast inside the
    model-parallel group so that all ranks agree on them.

    Args:
        args: Run configuration. Reads ``make_vocab_size_divisible_by``,
            ``do_train``, ``do_valid`` and ``do_test``. The three ``do_*``
            attributes are overwritten with the broadcast integer values.

    Returns:
        Tuple ``(train_data, val_data, test_data, num_tokens)`` where
        ``num_tokens`` is the padded vocabulary size received from the
        broadcast.
    """
    train_data = val_data = test_data = None

    loads_data = mpu.get_model_parallel_rank() == 0
    if not loads_data:
        # Placeholder buffer; it is filled in by the broadcast below.
        shared = torch.cuda.LongTensor([0, 0, 0, 0])
    else:
        # Data loader only on rank 0 of each model parallel group.
        train_data, val_data, test_data = make_loaders(args)
        vocab_size = get_tokenizer().num_tokens

        # Grow the vocab until it is divisible by the requested factor
        # times the model-parallel world size.
        divisor = (args.make_vocab_size_divisible_by
                   * mpu.get_model_parallel_world_size())
        padded_size = vocab_size
        while padded_size % divisor != 0:
            padded_size += 1

        message = ('> padded vocab (size: {}) with {} dummy tokens '
                   '(new size: {})')
        print_rank_0(message.format(
            vocab_size, padded_size - vocab_size, padded_size))

        payload = [padded_size]
        payload.append(int(args.do_train))
        payload.append(int(args.do_valid))
        payload.append(int(args.do_test))
        shared = torch.cuda.LongTensor(payload)

    # Broadcast num tokens (and the do_* flags) within the group.
    torch.distributed.broadcast(
        shared,
        mpu.get_model_parallel_src_rank(),
        group=mpu.get_model_parallel_group(),
    )
    num_tokens, args.do_train, args.do_valid, args.do_test = shared.tolist()

    return train_data, val_data, test_data, num_tokens
| 713 | |
| 714 | |
| 715 | def main(): |
no test coverage detected