MCPcopy Index your code
hub / github.com/THUDM/GLM / make_data_loader

Function make_data_loader

configure_data.py:154–208  ·  view source on GitHub ↗
(dataset, tokenizer, batch_size, num_iters, args, shuffle=False, block_collate=False)

Source from the content-addressed store, hash-verified

152
153
def make_data_loader(dataset, tokenizer, batch_size, num_iters, args, shuffle=False, block_collate=False):
    """Build a ``torch.utils.data.DataLoader`` over ``dataset`` for this data-parallel rank.

    Args:
        dataset: dataset to load from; ``len(dataset)`` is used in the
            ``args.transformer_xl`` path.
        tokenizer: forwarded to ``ConstructBlockStrategy`` when ``block_collate``
            is set; otherwise unused.
        batch_size: batch size handed to the batch sampler (integer-divided by
            ``args.loader_scatter`` when that option is set).
            NOTE(review): presumably the batch across all data-parallel ranks,
            which ``DistributedBatchSampler`` splits by rank — confirm.
        num_iters: iteration count passed to ``DistributedSequentialSampler``
            (only used when ``args.transformer_xl`` is set).
        args: run options. Read here: ``loader_scatter``, ``transformer_xl``,
            ``train_iters`` and ``gradient_accumulation_steps`` (shuffle path),
            ``num_workers`` and, when ``block_collate`` is set, ``seq_length`` plus
            the block-masking options forwarded to ``ConstructBlockStrategy``.
        shuffle: sample with replacement via ``data_utils.samplers.RandomSampler``
            instead of ``SequentialSampler``. Ignored when ``args.transformer_xl``
            is set.
        block_collate: use ``ConstructBlockStrategy(...).construct_blocks`` as the
            collate function instead of the DataLoader default.

    Returns:
        A ``torch.utils.data.DataLoader`` driven by a batch sampler, with
        ``pin_memory=True``.

    Raises:
        ValueError: if ``args.loader_scatter`` is not a positive divisor of the
            data-parallel world size.
    """
    data_parallel_group = mpu.get_data_parallel_group()
    world_size = torch.distributed.get_world_size(group=data_parallel_group)
    rank = torch.distributed.get_rank(group=data_parallel_group)
    if args.loader_scatter is not None:
        # Consecutive ranks are merged into groups of ``loader_scatter`` that share one
        # reduced rank. An uneven split would leave the last ranks with a reduced rank
        # >= the reduced world size, so reject that configuration explicitly.
        if args.loader_scatter <= 0 or world_size % args.loader_scatter != 0:
            raise ValueError(
                f"args.loader_scatter ({args.loader_scatter}) must be a positive divisor "
                f"of the data-parallel world size ({world_size})")
        rank = rank // args.loader_scatter
        world_size = world_size // args.loader_scatter
        batch_size = batch_size // args.loader_scatter
    distributed = world_size > 1

    if args.transformer_xl:
        batch_sampler = data_utils.samplers.DistributedSequentialSampler(
            len(dataset), num_iters, batch_size, rank, world_size)
    else:
        if shuffle:
            # NOTE(review): the sample count uses args.train_iters rather than the
            # ``num_iters`` parameter; kept unchanged because callers may pass a
            # different num_iters while relying on this sizing — confirm intent.
            sampler = data_utils.samplers.RandomSampler(
                dataset, replacement=True,
                num_samples=batch_size * args.train_iters * args.gradient_accumulation_steps)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)
        # Incomplete trailing batches are dropped only when sharding across ranks.
        drop_last = distributed
        # the GPUs in the same model parallel group receive the same data
        if distributed:
            batch_sampler = data_utils.samplers.DistributedBatchSampler(
                sampler, batch_size, drop_last, rank, world_size,
                gradient_accumulation_steps=args.gradient_accumulation_steps)
        else:
            batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

    collate_fn = None
    if block_collate:
        collate_fn = ConstructBlockStrategy(
            args, tokenizer, args.seq_length,
            bert_prob=args.bert_prob,
            gap_sentence_prob=args.gap_sentence_prob,
            gap_sentence_ratio=args.gap_sentence_ratio,
            gpt_infill_prob=args.gpt_infill_prob,
            average_block_length=args.avg_block_length,
            gpt_min_ratio=args.gpt_min_ratio,
            block_mask_prob=args.block_mask_prob,
            context_mask_ratio=args.context_mask_ratio,
            short_seq_prob=args.short_seq_prob,
            single_span_prob=args.single_span_prob,
            shuffle_blocks=not args.no_shuffle_block,
            block_position_encoding=not args.no_block_position,
            sentinel_token=args.sentinel_token,
            encoder_decoder=args.encoder_decoder,
            task_mask=args.task_mask,
            random_position=args.random_position,
            masked_lm=args.masked_lm).construct_blocks

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_sampler=batch_sampler,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn)

    return data_loader
209
210
211def make_tfrecord_loaders(args):

Callers 3

finetune (function) · 0.90
make_loaders (function) · 0.85
build_multi_task_dataset (function) · 0.85

Calls 1

Tested by

no test coverage detected