| 45 | |
| 46 | |
def make_data_loader(dataset, batch_size, args):
    """Build a ``torch.utils.data.DataLoader`` over ``dataset``.

    The index sampler depends on ``args.shuffle``:

    * truthy  -> ``data_utils.samplers.RandomSampler`` constructed with
      ``replacement=True`` and ``num_samples=batch_size * args.train_iters``;
    * falsy   -> ``torch.utils.data.SequentialSampler``.

    The size and rank of the data-parallel process group (obtained from
    ``mpu.get_data_parallel_group()``) decide how indices are batched. With
    more than one member, ``data_utils.samplers.DistributedBatchSampler`` is
    used with ``drop_last=True``. With a single member, a plain
    ``torch.utils.data.BatchSampler`` is used with ``drop_last=False``.

    Args:
        dataset: dataset to load from; passed to the sampler and DataLoader.
        batch_size: batch size handed to the batch sampler; also scales the
            number of samples drawn when shuffling.
        args: namespace that must provide ``shuffle``, ``train_iters`` and
            ``num_workers``.

    Returns:
        A DataLoader with ``pin_memory=True`` and ``args.num_workers`` workers.
    """
    sampler = (
        data_utils.samplers.RandomSampler(
            dataset,
            replacement=True,
            num_samples=batch_size * args.train_iters,
        )
        if args.shuffle
        else torch.utils.data.SequentialSampler(dataset)
    )

    dp_group = mpu.get_data_parallel_group()
    num_replicas = torch.distributed.get_world_size(group=dp_group)
    replica_rank = torch.distributed.get_rank(group=dp_group)

    # Partial final batches are dropped only when sharding across ranks
    # (presumably so every rank receives equally sized batches).
    if num_replicas > 1:
        batch_sampler = data_utils.samplers.DistributedBatchSampler(
            sampler,
            batch_size,
            True,
            replica_rank,
            num_replicas,
        )
    else:
        batch_sampler = torch.utils.data.BatchSampler(
            sampler,
            batch_size,
            False,
        )

    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=args.num_workers,
        pin_memory=True,
    )
| 77 | |
| 78 | |
| 79 | def make_tfrecord_loaders(args): |