Repository: github.com/deepspeedai/DeepSpeedExamples

Function get_train_val_test_data

Megatron-LM/pretrain_bert.py:466–502

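Load the training, validation, and test data on rank 0 of each model-parallel group, then broadcast the padded vocabulary size, the number of token types, and the do_train/do_valid/do_test flags to the other ranks in that group. Returns (train_data, val_data, test_data, num_tokens, num_type_tokens); the three data objects are None on every rank other than model-parallel rank 0.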

(args)
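The load-on-rank-0-then-broadcast pattern described above can be illustrated without GPUs by simulating one model-parallel group in plain Python. This is a minimal sketch under stated assumptions: ranks are modeled as list entries, simulated_broadcast is a hypothetical stand-in for torch.distributed.broadcast, and pack_on_rank and unpack are hypothetical helpers, not part of the Megatron-LM code. The five-field layout mirrors the token_counts tensor in the listing.

```python
from typing import List, Tuple

# Field order mirrors token_counts in the listing:
# [num_tokens, num_type_tokens, do_train, do_valid, do_test]
NUM_FIELDS = 5


def pack_on_rank(group_rank: int, padded_vocab: int, num_type_tokens: int,
                 do_train: bool, do_valid: bool, do_test: bool) -> List[int]:
    """Hypothetical helper: only group rank 0 has loaded the data and knows real values."""
    if group_rank == 0:
        return [padded_vocab, num_type_tokens,
                int(do_train), int(do_valid), int(do_test)]
    # Other ranks allocate a same-sized placeholder that the broadcast overwrites.
    return [0] * NUM_FIELDS


def simulated_broadcast(buffers: List[List[int]], src: int) -> None:
    """Hypothetical stand-in for torch.distributed.broadcast: copy src's buffer into every rank in place."""
    src_values = list(buffers[src])
    for buf in buffers:
        buf[:] = src_values


def unpack(buf: List[int]) -> Tuple[int, int, int, int, int]:
    """Every rank decodes the same five integers after the broadcast."""
    num_tokens, num_type_tokens, do_train, do_valid, do_test = buf
    return num_tokens, num_type_tokens, do_train, do_valid, do_test


if __name__ == "__main__":
    group_size = 4  # assumed model-parallel group size
    # Illustrative values only; rank 0 would get them from the tokenizer and args.
    buffers = [pack_on_rank(r, 30720, 2, True, True, False) for r in range(group_size)]
    simulated_broadcast(buffers, src=0)
    for rank, buf in enumerate(buffers):
        print(rank, unpack(buf))  # every rank prints (30720, 2, 1, 1, 0)
```

Note that the flags come back as 0/1 integers rather than booleans. The listing behaves the same way: args.do_train, args.do_valid, and args.do_test are reassigned from .item() on a LongTensor, which returns a Python int.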


def get_train_val_test_data(args):
    """Load the data on rank zero and boradcast number of tokens to all GPUS."""

    (train_data, val_data, test_data) = (None, None, None)

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        data_config = configure_data()
        data_config.set_defaults(data_set_type='BERT', transpose=False)
        (train_data, val_data, test_data), tokenizer = data_config.apply(args)
        before = tokenizer.num_tokens
        after = before
        multiple = args.make_vocab_size_divisible_by * \
                   mpu.get_model_parallel_world_size()
        while (after % multiple) != 0:
            after += 1
        print_rank_0('> padded vocab (size: {}) with {} dummy '
                     'tokens (new size: {})'.format(
                         before, after - before, after))
        # Need to broadcast num_tokens and num_type_tokens.
        token_counts = torch.cuda.LongTensor([after,
                                              tokenizer.num_type_tokens,
                                              int(args.do_train), int(args.do_valid), int(args.do_test)])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    num_type_tokens = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, val_data, test_data, num_tokens, num_type_tokens
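The padding loop rounds the tokenizer's vocabulary size up to the next multiple of make_vocab_size_divisible_by times the model-parallel world size, which lets the vocabulary dimension split evenly across model-parallel ranks. The sketch below reproduces that loop and compares it with an equivalent closed form using integer ceiling division. The function names are hypothetical and the vocabulary size 30522 is only an illustrative input.

```python
def pad_vocab_loop(before: int, divisible_by: int, mp_world_size: int) -> int:
    """Same arithmetic as the while-loop in the listing: step up until divisible."""
    multiple = divisible_by * mp_world_size
    after = before
    while after % multiple != 0:
        after += 1
    return after


def pad_vocab_ceil(before: int, divisible_by: int, mp_world_size: int) -> int:
    """Closed form: round up to the next multiple via ceiling division."""
    multiple = divisible_by * mp_world_size
    return -(-before // multiple) * multiple


if __name__ == "__main__":
    before = 30522  # illustrative vocabulary size
    after = pad_vocab_ceil(before, divisible_by=128, mp_world_size=2)
    # Same message format as print_rank_0 in the listing.
    print('> padded vocab (size: {}) with {} dummy '
          'tokens (new size: {})'.format(before, after - before, after))
    # prints "> padded vocab (size: 30522) with 198 dummy tokens (new size: 30720)"
```

The closed form does a constant amount of work, while the loop performs up to multiple - 1 increments; for the small multiples used here either is fast, so the listing's loop is a readability choice rather than a bottleneck.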

Callers (1): main (function, 0.70)

Calls (4): configure_data (function, 0.90), print_rank_0 (function, 0.90), set_defaults (method, 0.80), apply (method, 0.80)

Tested by: no test coverage detected