(args, index, finetune=False, shuffle=True)
| 111 | return (not args.no_cuda and dist.get_rank() == 0) or (args.no_cuda and args.local_rank == -1) |
| 112 | |
| 113 | def get_train_dataset(args, index, finetune=False, shuffle=True): |
| 114 | assert not finetune, "finetune not supported" |
| 115 | i = 0 |
| 116 | dataloaders = {} |
| 117 | datalengths = [] |
| 118 | batchs_per_dataset = [] |
| 119 | batch_mapping = {} |
| 120 | |
| 121 | config = args.config |
| 122 | dataset_paths = config["data"]["datasets"] |
| 123 | dataset_flags = config["data"]["flags"] |
| 124 | |
| 125 | # Pretraining dataset |
| 126 | if dataset_flags.get("pretrain_dataset", False): |
| 127 | pretrain_type = dataset_flags.get("pretrain_type") |
| 128 | |
| 129 | if pretrain_type == "wiki_bc": |
| 130 | # Load Wiki Dataset |
| 131 | wiki_pretrain_dataset = PreTrainingDataset( |
| 132 | args.tokenizer, |
| 133 | os.path.join(args.data_path_prefix, dataset_paths['wiki_pretrain_dataset']), |
| 134 | args.logger, |
| 135 | args.max_seq_length, |
| 136 | index, |
| 137 | PretrainDataType.NUMPY, |
| 138 | args.max_predictions_per_seq) |
| 139 | datalengths.append(len(wiki_pretrain_dataset)) |
| 140 | dataloaders[i] = get_dataloader(args, wiki_pretrain_dataset) |
| 141 | batch_mapping[i] = PretrainBatch |
| 142 | batchs_per_dataset.append( |
| 143 | get_effective_batch(args, len(wiki_pretrain_dataset))) |
| 144 | i += 1 |
| 145 | |
| 146 | bc_pretrain_dataset = PreTrainingDataset( |
| 147 | args.tokenizer, |
| 148 | os.path.join(args.data_path_prefix, dataset_paths['bc_pretrain_dataset']), |
| 149 | args.logger, |
| 150 | args.max_seq_length, |
| 151 | index, |
| 152 | PretrainDataType.NUMPY, |
| 153 | args.max_predictions_per_seq |
| 154 | ) |
| 155 | datalengths.append(len(bc_pretrain_dataset)) |
| 156 | dataloaders[i] = get_dataloader(args, bc_pretrain_dataset) |
| 157 | batch_mapping[i] = PretrainBatch |
| 158 | batchs_per_dataset.append( |
| 159 | get_effective_batch(args, len(bc_pretrain_dataset))) |
| 160 | i += 1 |
| 161 | |
| 162 | dataset_batches = [] |
| 163 | for i, batch_count in enumerate(batchs_per_dataset): |
| 164 | dataset_batches.extend([i] * batch_count) |
| 165 | |
| 166 | # shuffle |
| 167 | if shuffle: |
| 168 | random.shuffle(dataset_batches) |
| 169 | |
| 170 | dataset_picker = [] |
no test coverage detected