(data_path)
| 40 | num_workers = args.num_workers |
| 41 | |
def make_data_loader_(data_path):
    """Return a torch DataLoader over the GPT2Dataset built from ``data_path``.

    Apart from ``data_path``, every setting used here
    (``input_data_sizes_file``, ``seq_length``, ``initial_seed``,
    ``global_batch_size``, ``rank``, ``world_size``, ``num_workers``) is
    not defined in this function and is resolved from the enclosing scope.
    """
    # Dataset first: the sampler below is constructed over it.
    gpt2_data = GPT2Dataset(
        data_path, input_data_sizes_file, seq_length, initial_seed
    )

    # Samples are visited in order; batching is delegated to
    # DistributedBatchSampler, which receives this process's rank and the
    # world size. drop_last=True is forwarded unchanged.
    # NOTE(review): presumably each rank gets its own slice of every
    # global batch -- confirm against DistributedBatchSampler.
    rank_batches = DistributedBatchSampler(
        sampler=torch.utils.data.SequentialSampler(gpt2_data),
        batch_size=global_batch_size,
        drop_last=True,
        rank=rank,
        world_size=world_size,
    )

    # Batching is fully described by the batch sampler, so no batch_size
    # or shuffle argument is passed to the DataLoader itself.
    loader_options = {
        "batch_sampler": rank_batches,
        "num_workers": num_workers,
        "pin_memory": True,
    }
    return torch.utils.data.DataLoader(gpt2_data, **loader_options)
| 58 | |
| 59 | train = make_data_loader_(args.train_data_path) |
| 60 | valid = make_data_loader_(args.val_data_path) |
no test coverage detected