(args, tokenizer)
| 346 | |
| 347 | |
def build_multi_task_dataset(args, tokenizer):
    """Build the train and validation data loaders for multi-task training.

    Only the process whose model-parallel rank is 0 builds anything; every
    other rank gets ``(None, None)`` back.

    For each entry of ``args.multi_task_data`` (lower-cased before lookup) a
    "train" and a "dev" ``SuperGlueDataset`` are created from the task's
    sub-directory under ``args.data_dir``. The per-task datasets are combined
    into one ``MultiTaskDataset`` per split and wrapped by ``make_data_loader``.

    Args:
        args: Run configuration. Reads ``seq_length``, ``multi_seq_length``,
            ``multi_task_data``, ``data_dir``, ``batch_size``,
            ``multi_batch_size`` and ``train_iters``; ``args`` itself is also
            forwarded to ``SuperGlueDataset`` and ``make_data_loader``.
        tokenizer: Forwarded unchanged to ``SuperGlueDataset`` and
            ``make_data_loader``.

    Returns:
        A ``(train, valid)`` tuple of data loaders, or ``(None, None)`` when
        the model-parallel rank is not 0.

    Raises:
        KeyError: If a lower-cased task name has no entry in the task
            directory table.
    """
    # Lower-cased task name -> sub-directory of args.data_dir.
    task_dirs = {
        "mnli": "MNLI",
        "cola": "CoLA",
        "mrpc": "MRPC",
        "qnli": "QNLI",
        "qqp": "QQP",
        "sst2": "SST-2",
        "agnews": "Agnews",
        "yelp-polarity": "yelp_review_polarity_csv",
        "yelp-full": "yelp_review_full_csv",
        "yahoo": "Yahoo",
        "squad": "SQuAD",
        "race": "RACE",
    }

    if mpu.get_model_parallel_rank() != 0:
        return None, None

    # args.multi_seq_length, when set, takes precedence over args.seq_length.
    default_len, override_len = args.seq_length, args.multi_seq_length
    seq_len = default_len if override_len is None else override_len

    train_sets, valid_sets = [], []
    for task_name in args.multi_task_data:
        task_key = task_name.lower()
        task_path = os.path.join(args.data_dir, task_dirs[task_key])
        # Build the "train" split first, then "dev", for each task in turn.
        for split, bucket in (("train", train_sets), ("dev", valid_sets)):
            bucket.append(
                SuperGlueDataset(
                    args,
                    task_key,
                    task_path,
                    seq_len,
                    split,
                    tokenizer,
                    pattern_ensemble=True,
                )
            )

    # The combined datasets receive the task names exactly as given in args
    # (not lower-cased).
    train_set = MultiTaskDataset(args.multi_task_data, train_sets)
    valid_set = MultiTaskDataset(args.multi_task_data, valid_sets)

    # Batch size handed to the loader is the per-rank size multiplied by the
    # size of the data-parallel group.
    dp_world_size = torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group()
    )
    total_batch = args.batch_size * dp_world_size
    if args.multi_batch_size is not None:
        total_batch = args.multi_batch_size * dp_world_size

    # NOTE(review): the validation loader is also built with shuffle=True and
    # args.train_iters, same as the training loader — confirm this is intended.
    train_loader = make_data_loader(
        train_set, tokenizer, total_batch, args.train_iters, args, shuffle=True
    )
    valid_loader = make_data_loader(
        valid_set, tokenizer, total_batch, args.train_iters, args, shuffle=True
    )
    return train_loader, valid_loader
| 374 | |
| 375 | |
| 376 | def get_split(args): |
no test coverage detected