(data_path, seed)
| 198 | process(0, train_datasets, tokenizer, max_seq_len, 1, num_samples, output_path, args) |
| 199 | |
def get_dataset(data_path, seed):
    """Load every ``train_data*.pt`` shard under *data_path* and shuffle them.

    All matching shards are loaded with ``torch.load``, concatenated into a
    single ``ConcatDataset``, and wrapped in a ``Subset`` whose index list
    comes from ``get_shuffle_idx(seed, total_size)``.

    Args:
        data_path: Directory searched (non-recursively) for files matching
            ``train_data*.pt``.
        seed: Forwarded to ``get_shuffle_idx``; presumably seeds the
            permutation (confirm against that helper).

    Returns:
        A ``Subset`` over the concatenation of all loaded shards, indexed by
        the list returned from ``get_shuffle_idx(...).tolist()``.

    Raises:
        AssertionError: If no file under *data_path* matches the pattern.
    """
    # glob.glob returns matches in arbitrary (filesystem-dependent) order.
    # Sort so the concatenation order, and therefore the samples picked by
    # the seeded shuffle indices, is the same on every machine and run.
    files = sorted(glob.glob(os.path.join(data_path, "train_data*.pt")))
    assert len(files) > 0, "There is no data here!"

    # NOTE(review): torch.load unpickles the file contents; only point this
    # at shard files from a trusted source.
    train_datasets = [torch.load(file) for file in files]
    train_size = sum(len(shard) for shard in train_datasets)

    train_dataset = ConcatDataset(train_datasets)
    shuffle_idx = get_shuffle_idx(seed, train_size)
    return Subset(train_dataset, shuffle_idx.tolist())
no test coverage detected