Load the data on rank zero and broadcast the number of tokens to all GPUs.
(args)
| 550 | |
| 551 | |
def get_train_val_test_data(args):
    """Build the datasets on model-parallel rank 0 and share vocab metadata.

    Only the process whose model-parallel rank is 0 constructs the data
    loaders; on every other rank the three dataset slots stay ``None``.
    The padded vocabulary size, the end-of-document token id and the
    ``do_train`` / ``do_valid`` / ``do_test`` flags are then broadcast
    inside the model-parallel group so that all ranks agree on them.

    Args:
        args: Parsed run arguments. Reads ``use_npy_data_loader``,
            ``make_vocab_size_divisible_by``, ``do_train``, ``do_valid``
            and ``do_test``. The three ``do_*`` attributes are overwritten
            with the broadcast integer values before returning.

    Returns:
        Tuple ``(train_data, val_data, test_data, num_tokens, eod_token)``
        where ``num_tokens`` is the *padded* vocabulary size.
    """
    train_data = val_data = test_data = None

    if mpu.get_model_parallel_rank() != 0:
        # Placeholder buffer; the broadcast below fills in the real values.
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
    else:
        # Data loader only on rank 0 of each model parallel group.
        if args.use_npy_data_loader:
            loaders, num_tokens, eod_token = make_gpt2_dataloaders(args)
            train_data, val_data, test_data = loaders
        else:
            data_config = configure_data()
            data_config.set_defaults(data_set_type='GPT2', transpose=False)
            loaders, tokenizer = data_config.apply(args)
            train_data, val_data, test_data = loaders
            num_tokens = tokenizer.num_tokens
            eod_token = tokenizer.get_command('eos').Id
            # The 'eos' and 'pad' commands must map to the same token id.
            assert eod_token == tokenizer.get_command('pad').Id

        # NOTE(review): nesting was reconstructed from a flattened listing;
        # the padding below is taken to apply to both loader paths - confirm.
        # Grow the vocab size to the next multiple of
        # (make_vocab_size_divisible_by * model-parallel world size).
        unpadded_size = num_tokens
        divisor = (args.make_vocab_size_divisible_by
                   * mpu.get_model_parallel_world_size())
        padded_size = unpadded_size
        while padded_size % divisor != 0:
            padded_size += 1

        padding_msg = ('> padded vocab (size: {}) with {} dummy '
                       'tokens (new size: {})')
        print_rank_0(padding_msg.format(
            unpadded_size, padded_size - unpadded_size, padded_size))
        print_rank_0('> found end-of-document token: {}'.format(eod_token))

        token_counts = torch.cuda.LongTensor([
            padded_size,
            eod_token,
            int(args.do_train),
            int(args.do_valid),
            int(args.do_test),
        ])

    # Broadcast num tokens (plus eod token and the do_* flags).
    torch.distributed.broadcast(
        token_counts,
        mpu.get_model_parallel_src_rank(),
        group=mpu.get_model_parallel_group(),
    )

    # Pull all five values back to the host as plain Python ints.
    num_tokens, eod_token, do_train, do_valid, do_test = token_counts.tolist()
    args.do_train = do_train
    args.do_valid = do_valid
    args.do_test = do_test

    return train_data, val_data, test_data, num_tokens, eod_token
| 595 | |
| 596 | |
| 597 | def main(): |
no test coverage detected