(dataset, tokenizer, batch_size, num_iters, args, shuffle=False, block_collate=False)
| 152 | |
| 153 | |
def make_data_loader(dataset, tokenizer, batch_size, num_iters, args, shuffle=False, block_collate=False):
    """Build a ``torch.utils.data.DataLoader`` for this data-parallel rank.

    Args:
        dataset: Dataset handed to the sampler and to the ``DataLoader``.
            ``len(dataset)`` is taken on the ``args.transformer_xl`` path.
        tokenizer: Only used when ``block_collate`` is true; forwarded to
            ``ConstructBlockStrategy``.
        batch_size: Batch size given to the batch sampler. Floor-divided by
            ``args.loader_scatter`` when that option is not ``None``.
        num_iters: Only forwarded to ``DistributedSequentialSampler`` (the
            ``args.transformer_xl`` path); unused otherwise.
        args: Parsed options. Reads ``loader_scatter``, ``transformer_xl``,
            ``train_iters``, ``gradient_accumulation_steps``, ``num_workers``
            and, when ``block_collate`` is true, the block-construction
            options passed to ``ConstructBlockStrategy``.
        shuffle: When true (and ``args.transformer_xl`` is false) sample with
            replacement via ``RandomSampler``; otherwise iterate sequentially.
        block_collate: When true, use
            ``ConstructBlockStrategy(...).construct_blocks`` as ``collate_fn``;
            otherwise ``collate_fn`` is ``None``.

    Returns:
        A ``DataLoader`` driven by the selected batch sampler, created with
        ``pin_memory=True`` and ``num_workers=args.num_workers``.
    """
    dp_world_size = torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
    dp_rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())

    # With loader scattering enabled, rank, world size and batch size are all
    # floor-divided by the same scatter factor.
    scatter = args.loader_scatter
    per_loader_batch = batch_size
    if scatter is not None:
        dp_rank = dp_rank // scatter
        dp_world_size = dp_world_size // scatter
        per_loader_batch = per_loader_batch // scatter

    is_distributed = dp_world_size > 1

    if args.transformer_xl:
        batch_sampler = data_utils.samplers.DistributedSequentialSampler(
            len(dataset), num_iters, per_loader_batch, dp_rank, dp_world_size
        )
    else:
        if shuffle:
            sampler = data_utils.samplers.RandomSampler(
                dataset,
                replacement=True,
                num_samples=per_loader_batch * args.train_iters * args.gradient_accumulation_steps,
            )
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

        # the GPUs in the same model parallel group receive the same data
        # drop_last mirrors is_distributed: True here, False in the else branch.
        if is_distributed:
            batch_sampler = data_utils.samplers.DistributedBatchSampler(
                sampler,
                per_loader_batch,
                True,
                dp_rank,
                dp_world_size,
                gradient_accumulation_steps=args.gradient_accumulation_steps,
            )
        else:
            batch_sampler = torch.utils.data.BatchSampler(sampler, per_loader_batch, False)

    collate_fn = None
    if block_collate:
        block_strategy = ConstructBlockStrategy(
            args,
            tokenizer,
            args.seq_length,
            bert_prob=args.bert_prob,
            gap_sentence_prob=args.gap_sentence_prob,
            gap_sentence_ratio=args.gap_sentence_ratio,
            gpt_infill_prob=args.gpt_infill_prob,
            average_block_length=args.avg_block_length,
            gpt_min_ratio=args.gpt_min_ratio,
            block_mask_prob=args.block_mask_prob,
            context_mask_ratio=args.context_mask_ratio,
            short_seq_prob=args.short_seq_prob,
            single_span_prob=args.single_span_prob,
            shuffle_blocks=not args.no_shuffle_block,
            block_position_encoding=not args.no_block_position,
            sentinel_token=args.sentinel_token,
            encoder_decoder=args.encoder_decoder,
            task_mask=args.task_mask,
            random_position=args.random_position,
            masked_lm=args.masked_lm,
        )
        collate_fn = block_strategy.construct_blocks

    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
| 209 | |
| 210 | |
| 211 | def make_tfrecord_loaders(args): |
no test coverage detected