MCPcopy
hub / github.com/deepspeedai/DeepSpeedExamples / get_train_val_test_data

Function get_train_val_test_data

Megatron-LM/pretrain_gpt2.py:552–594  ·  view source on GitHub ↗

Load the data on rank zero and broadcast the number of tokens to all GPUs.

(args)

Source from the content-addressed store, hash-verified

550
551
def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast number of tokens to all GPUs.

    Only rank 0 of each model-parallel group builds the train/val/test data.
    That rank then broadcasts the following to the rest of its
    model-parallel group:
      - the padded vocabulary size,
      - the end-of-document token id,
      - the ``do_train``/``do_valid``/``do_test`` flags.

    Args:
        args: Parsed arguments. Reads ``use_npy_data_loader`` and
            ``make_vocab_size_divisible_by``. On the source rank it also
            reads ``do_train``/``do_valid``/``do_test``. Those three
            attributes are overwritten on every rank with the broadcast
            values, which are ints.

    Returns:
        Tuple ``(train_data, val_data, test_data, num_tokens, eod_token)``.
        The three data objects are ``None`` on ranks whose model-parallel
        rank is not 0. ``num_tokens`` is the vocabulary size padded up to a
        multiple of ``make_vocab_size_divisible_by *
        model_parallel_world_size``.
    """
    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        if args.use_npy_data_loader:
            (train_data, val_data, test_data), num_tokens, \
                eod_token = make_gpt2_dataloaders(args)
        else:
            data_config = configure_data()
            data_config.set_defaults(data_set_type='GPT2', transpose=False)
            (train_data, val_data, test_data), tokenizer = data_config.apply(
                args)
            num_tokens = tokenizer.num_tokens
            eod_token = tokenizer.get_command('eos').Id
            # This code path requires the eos and pad commands to share an id.
            assert eod_token == tokenizer.get_command('pad').Id

        before = num_tokens
        multiple = args.make_vocab_size_divisible_by * \
            mpu.get_model_parallel_world_size()
        # Round up to the next multiple of `multiple`. This is the closed
        # form of an increment-until-divisible loop. A zero `multiple` still
        # raises ZeroDivisionError.
        after = ((before + multiple - 1) // multiple) * multiple
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(
                         before, after - before, after))
        print_rank_0('> found end-of-document token: {}'.format(eod_token))
        token_counts = torch.cuda.LongTensor([
            after,
            eod_token,
            int(args.do_train),
            int(args.do_valid),
            int(args.do_test),
        ])
    else:
        # Placeholder buffer; overwritten in place by the broadcast below.
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    # A single host transfer instead of one `.item()` sync per element.
    (num_tokens, eod_token,
     args.do_train, args.do_valid, args.do_test) = token_counts.tolist()

    return train_data, val_data, test_data, num_tokens, eod_token
595
596
597def main():

Callers 1

mainFunction · 0.70

Calls 6

make_gpt2_dataloadersFunction · 0.90
configure_dataFunction · 0.90
print_rank_0Function · 0.90
set_defaultsMethod · 0.80
applyMethod · 0.80
get_commandMethod · 0.80

Tested by

no test coverage detected