MCPcopy Index your code
hub / github.com/THUDM/GLM / make_loaders

Function make_loaders

configure_data.py:248–345  ·  view source on GitHub ↗

Makes the training, validation, and test data loaders.

(args, tokenizer)

Source from the content-addressed store, hash-verified

246
247
248def make_loaders(args, tokenizer):
249 """makes training/val/test"""
250
251 if args.use_tfrecords:
252 return make_tfrecord_loaders(args)
253 world_size = torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
254 if args.loader_scatter is not None:
255 assert world_size % args.loader_scatter == 0
256 batch_size = args.batch_size * world_size
257 eval_batch_size = batch_size
258 if args.eval_batch_size is not None:
259 eval_batch_size = args.eval_batch_size * world_size
260 seq_length = args.seq_length
261 if seq_length < 0:
262 seq_length = seq_length * world_size
263 eval_seq_length = args.eval_seq_length
264 if eval_seq_length is not None and eval_seq_length < 0:
265 eval_seq_length = eval_seq_length * world_size
266 split = get_split(args)
267 data_set_args = {
268 'path': args.train_data,
269 'seq_length': seq_length,
270 'mem_length': args.mem_length,
271 'delim': args.delim,
272 'text_key': args.text_key,
273 'label_key': 'label',
274 'ds_type': args.data_set_type,
275 'split': split,
276 'loose': args.loose_json,
277 'max_preds_per_seq': args.max_preds_per_seq,
278 'presplit_sentences': args.presplit_sentences,
279 'sample_one_document': args.sample_one_document,
280 'filter_english': args.filter_english,
281 'pre_tokenize': not args.no_pre_tokenize,
282 'tokenizer': tokenizer,
283 'save_splits': args.save_splits,
284 'load_splits': args.load_splits,
285 'save_test_data': args.save_test_data,
286 'no_lazy_loader': args.no_lazy_loader,
287 'loader_scatter': args.loader_scatter,
288 'data_parallel_rank': mpu.get_data_parallel_rank(),
289 "non_sentence_start": args.non_sentence_start,
290 "half_lazy_loader": args.half_lazy_loader
291 }
292
293 eval_set_args = copy.copy(data_set_args)
294 eval_set_args['split'] = [1.]
295 # if optional eval args were set then replace their
296 # equivalent values in the arg dict
297 if eval_seq_length:
298 eval_set_args['seq_length'] = eval_seq_length
299 if args.eval_max_preds_per_seq:
300 eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq
301 if args.eval_text_key is not None:
302 eval_set_args['text_key'] = args.eval_text_key
303
304 # make datasets splits and tokenizer
305 train, valid, test = None, None, None

Callers 1

apply (Method) · 0.85

Calls 3

make_tfrecord_loaders (Function) · 0.85
get_split (Function) · 0.85
make_data_loader (Function) · 0.85

Tested by

no test coverage detected