hub / github.com/zai-org/CogView / get_train_val_test_data

Function get_train_val_test_data

pretrain_gpt2.py:680–712

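Load the data on rank zero of each model-parallel group, then broadcast the padded vocabulary size and the do_train, do_valid, and do_test flags to the other ranks in that group.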

(args)

Source from the content-addressed store, hash-verified

def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast number of tokens to all GPUs."""

    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        train_data, val_data, test_data = make_loaders(args)
        num_tokens = get_tokenizer().num_tokens

        # Pad the vocabulary up to the next multiple of
        # make_vocab_size_divisible_by * model-parallel world size.
        before = num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * \
            mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(
                         before, after - before, after))
        token_counts = torch.cuda.LongTensor(
            [after, int(args.do_train), int(args.do_valid), int(args.do_test)])
    else:
        # Placeholder values, overwritten by the broadcast below.
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0])
    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    args.do_train = token_counts[1].item()
    args.do_valid = token_counts[2].item()
    args.do_test = token_counts[3].item()

    return train_data, val_data, test_data, num_tokens
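
The padding arithmetic and the four-element payload in the function above can be tried in isolation. The following is a minimal sketch and not part of the repository: the helper names pad_vocab_size, pack_counts, and unpack_counts are hypothetical, a plain Python list stands in for the CUDA tensor, the broadcast itself is omitted, and the vocabulary size 50257 with divisor 128 and a model-parallel world size of 2 are illustrative values only.

```python
def pad_vocab_size(num_tokens, divisible_by, model_parallel_size):
    """Round num_tokens up to the next multiple of divisible_by * model_parallel_size."""
    multiple = divisible_by * model_parallel_size
    # Ceiling division; yields the same result as the `while` loop in the source.
    return ((num_tokens + multiple - 1) // multiple) * multiple


def pack_counts(padded_vocab, do_train, do_valid, do_test):
    """Build the four-element payload that rank 0 would place in the tensor."""
    return [padded_vocab, int(do_train), int(do_valid), int(do_test)]


def unpack_counts(counts):
    """Read the payload back, as every rank does after the broadcast."""
    num_tokens, do_train, do_valid, do_test = counts
    return num_tokens, do_train, do_valid, do_test


if __name__ == "__main__":
    padded = pad_vocab_size(50257, 128, 2)   # illustrative values, multiple = 256
    print(padded, padded - 50257)            # 50432 175
    print(unpack_counts(pack_counts(padded, True, False, True)))  # (50432, 1, 0, 1)
```

One consequence visible in the sketch: the flags are converted with int() before packing, and in the original function they are read back with .item(), so args.do_train, args.do_valid, and args.do_test end up as the integers 0 or 1 rather than booleans. Truthiness checks such as `if args.do_train:` behave the same either way.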

Callers 1

main · Function · 0.85

Calls 3

make_loaders · Function · 0.90
get_tokenizer · Function · 0.90
print_rank_0 · Function · 0.90

Tested by

no test coverage detected