MCPcopy Index your code
hub / github.com/THUDM/GLM / build_multi_task_dataset

Function build_multi_task_dataset

configure_data.py:348–373  ·  view source on GitHub ↗
(args, tokenizer)

Source from the content-addressed store, hash-verified

346
347
def build_multi_task_dataset(args, tokenizer):
    """Build train/validation data loaders that mix several tasks.

    For every task name in ``args.multi_task_data`` a ``SuperGlueDataset`` is
    built for the "train" and "dev" splits (with ``pattern_ensemble=True``),
    reading from ``os.path.join(args.data_dir, <task directory>)``. The
    per-task datasets are combined with ``MultiTaskDataset`` and each split is
    wrapped with ``make_data_loader``.

    Datasets are only constructed on model-parallel rank 0; on every other
    rank ``None`` is handed to ``make_data_loader`` for both splits.

    Args:
        args: Run configuration. Reads ``seq_length``, ``multi_seq_length``,
            ``multi_task_data``, ``data_dir``, ``batch_size``,
            ``multi_batch_size`` and ``train_iters``.
        tokenizer: Tokenizer forwarded to the datasets and the data loaders.

    Returns:
        A ``(train, valid)`` tuple of whatever ``make_data_loader`` returns.

    Raises:
        KeyError: On model-parallel rank 0, if a task in
            ``args.multi_task_data`` (compared case-insensitively) has no
            known data directory. Raised before any dataset is built.
    """
    # Lower-cased task name -> sub-directory of ``args.data_dir``.
    task_dirs = {"mnli": "MNLI", "cola": "CoLA", "mrpc": "MRPC", "qnli": "QNLI", "qqp": "QQP", "sst2": "SST-2",
                 "agnews": "Agnews", "yelp-polarity": "yelp_review_polarity_csv", "yelp-full": "yelp_review_full_csv",
                 "yahoo": "Yahoo", "squad": "SQuAD", "race": "RACE"}
    train, valid = None, None
    if mpu.get_model_parallel_rank() == 0:
        # A dedicated sequence length for multi-task data overrides the default one.
        multi_seq_length = args.seq_length if args.multi_seq_length is None else args.multi_seq_length

        # Validate every task name up front so a typo fails immediately with a
        # helpful message instead of after earlier tasks' datasets were loaded.
        tasks = [task.lower() for task in args.multi_task_data]
        unknown = [task for task in tasks if task not in task_dirs]
        if unknown:
            raise KeyError(
                f"Unknown multi-task dataset(s) {unknown}; expected any of {sorted(task_dirs)}")

        train_datasets, valid_datasets = [], []
        for task in tasks:
            data_dir = os.path.join(args.data_dir, task_dirs[task])
            train_datasets.append(
                SuperGlueDataset(args, task, data_dir, multi_seq_length, "train", tokenizer, pattern_ensemble=True))
            valid_datasets.append(
                SuperGlueDataset(args, task, data_dir, multi_seq_length, "dev", tokenizer, pattern_ensemble=True))
        # MultiTaskDataset receives the task names exactly as configured (not lower-cased).
        train = MultiTaskDataset(args.multi_task_data, train_datasets)
        valid = MultiTaskDataset(args.multi_task_data, valid_datasets)

    world_size = torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
    # The per-rank batch size (multi-task override if given) is scaled by the
    # data-parallel world size before being passed to the loader.
    per_rank_batch_size = args.batch_size if args.multi_batch_size is None else args.multi_batch_size
    multi_batch_size = per_rank_batch_size * world_size
    train = make_data_loader(train, tokenizer, multi_batch_size, args.train_iters, args, shuffle=True)
    # NOTE(review): the validation loader is also shuffled and sized by
    # args.train_iters — presumably so it can be sampled repeatedly during
    # training; confirm this is intended.
    valid = make_data_loader(valid, tokenizer, multi_batch_size, args.train_iters, args, shuffle=True)
    return train, valid
374
375
376def get_split(args):

Callers 1

main — Function · 0.90

Calls 4

SuperGlueDataset — Class · 0.90
MultiTaskDataset — Class · 0.85
make_data_loader — Function · 0.85
append — Method · 0.80

Tested by

no test coverage detected