(args)
| 27 | |
| 28 | |
def make_gpt2_dataloaders(args):
    """Build the train, validation and test data loaders for GPT-2.

    Args:
        args: Run configuration object. The attributes read here are
            ``input_data_sizes_file``, ``seq_length``, ``seed``,
            ``batch_size``, ``num_workers``, ``train_data_path``,
            ``val_data_path``, ``test_data_path`` and ``cache_dir``.
            As a side effect ``args.do_train``, ``args.do_valid`` and
            ``args.do_test`` are overwritten: each one is True only when
            the matching loader was actually built.

    Returns:
        A tuple ``((train, valid, test), num_tokens, eod_token)``. Each of
        ``train``/``valid``/``test`` is a ``torch.utils.data.DataLoader``,
        or None when the corresponding data path is None or empty.
        ``eod_token`` is the id the GPT-2 tokenizer assigns to
        ``'<|endoftext|>'`` and ``num_tokens`` is ``eod_token + 1``.
    """

    # Input parameters.
    input_data_sizes_file = args.input_data_sizes_file
    seq_length = args.seq_length
    initial_seed = args.seed

    # Data parallel arguments.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    # args.batch_size is multiplied by the data parallel world size to get
    # the batch size handed to the distributed batch sampler.
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    def make_data_loader_(data_path):
        """Return a DataLoader over `data_path`, or None if no path is set."""
        # Without this guard a loader was always returned, so the
        # `is not None` checks below could never be False and every
        # do_* flag was unconditionally set to True.
        if not data_path:
            return None
        # Build the dataset.
        dataset = GPT2Dataset(data_path, input_data_sizes_file,
                              seq_length, initial_seed)
        # Use a simple sampler with distributed batch sampler.
        sampler = torch.utils.data.SequentialSampler(dataset)
        batch_sampler = DistributedBatchSampler(sampler=sampler,
                                                batch_size=global_batch_size,
                                                drop_last=True,
                                                rank=rank,
                                                world_size=world_size)
        # Torch dataloader.
        return torch.utils.data.DataLoader(dataset,
                                           batch_sampler=batch_sampler,
                                           num_workers=num_workers,
                                           pin_memory=True)

    train = make_data_loader_(args.train_data_path)
    valid = make_data_loader_(args.val_data_path)
    test = make_data_loader_(args.test_data_path)

    # A phase is enabled only when its loader exists.
    args.do_train = train is not None
    args.do_valid = valid is not None
    args.do_test = test is not None

    # Tokenizer.
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=args.cache_dir)
    eod_token = tokenizer.encoder['<|endoftext|>']
    num_tokens = eod_token + 1

    return (train, valid, test), num_tokens, eod_token
| 80 | |
| 81 | |
| 82 | class GPT2Dataset(Dataset): |
no test coverage detected