Load the data on rank zero and broadcast the number of tokens to all GPUs.
(args)
| 464 | |
| 465 | |
def _pad_vocab_size(size, multiple):
    """Round ``size`` up to the nearest multiple of ``multiple`` (closed form)."""
    if multiple < 1:
        raise ValueError(
            'vocab padding multiple must be a positive integer, got {}'.format(
                multiple))
    # Ceil-divide then scale back up; equivalent to incrementing until
    # divisible, without the linear loop.
    return ((size + multiple - 1) // multiple) * multiple


def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast the token counts to its group.

    Only rank 0 of each model-parallel group builds the datasets and the
    tokenizer (via ``configure_data().apply(args)``). That rank pads the
    tokenizer's vocabulary size up to a multiple of
    ``args.make_vocab_size_divisible_by * mpu.get_model_parallel_world_size()``.
    It then broadcasts five values to the other ranks of its model-parallel
    group, so every rank in the group agrees on them:

    * the padded vocab size,
    * the number of token types,
    * the ``do_train`` / ``do_valid`` / ``do_test`` flags.

    Args:
        args: parsed arguments. Reads ``make_vocab_size_divisible_by`` and
            ``do_train`` / ``do_valid`` / ``do_test`` on model-parallel
            rank 0, and is passed to ``data_config.apply``. On every rank the
            three ``do_*`` attributes are overwritten with the broadcast
            values, as Python ints.

    Returns:
        Tuple ``(train_data, val_data, test_data, num_tokens, num_type_tokens)``.
        The three datasets remain ``None`` on ranks whose model-parallel rank
        is not 0. ``num_tokens`` is the padded vocab size.

    Raises:
        ValueError: on model-parallel rank 0, if the padding multiple is not a
            positive integer.
    """
    train_data, val_data, test_data = None, None, None

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        data_config = configure_data()
        data_config.set_defaults(data_set_type='BERT', transpose=False)
        (train_data, val_data, test_data), tokenizer = data_config.apply(args)

        before = tokenizer.num_tokens
        multiple = (args.make_vocab_size_divisible_by *
                    mpu.get_model_parallel_world_size())
        after = _pad_vocab_size(before, multiple)
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(
                         before, after - before, after))

        # Need to broadcast num_tokens, num_type_tokens and the do_* flags.
        counts = [after,
                  tokenizer.num_type_tokens,
                  int(args.do_train), int(args.do_valid), int(args.do_test)]
    else:
        # Placeholder buffer; overwritten in place by the broadcast below.
        counts = [0, 0, 0, 0, 0]
    token_counts = torch.cuda.LongTensor(counts)

    # Broadcast from the source rank of this model-parallel group.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    num_type_tokens = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, val_data, test_data, num_tokens, num_type_tokens
| 503 | |
| 504 | |
| 505 | def main(): |
no test coverage detected