| 33 | |
| 34 | |
class MultiTaskDataset(torch.utils.data.Dataset):
    """Present several per-task datasets as one map-style dataset.

    Two indexing modes are supported:

    * ``reweight=True``: each index deterministically draws a task with
      probability proportional to
      ``min(len(dataset), max_limit) ** temperature`` and then draws one
      sample uniformly from that task's dataset.
    * ``reweight=False``: indices walk the datasets back to back, in the
      order they were given, wrapping around after the last sample.

    ``len(self)`` is 1000 times the total number of underlying samples, so
    valid indices exceed that total in both modes.
    """

    def __init__(self, tasks, datasets, reweight: bool = True,
                 temperature: float = 0.8, max_limit: int = 200000):
        """Record the datasets and precompute sampling weights and offsets.

        Args:
            tasks: Task identifiers, parallel to ``datasets`` (only used for
                the summary that is printed here).
            datasets: Sized, indexable datasets, one per task.
            reweight: Select the sampling mode described on the class.
            temperature: Exponent applied to each (capped) dataset size when
                computing the task sampling weights.
            max_limit: Cap applied to each dataset size before
                exponentiation.
        """
        super().__init__()
        self.tasks = tasks
        self.datasets = datasets
        self.reweight = reweight
        self.temperature = temperature
        self.lens = [len(dataset) for dataset in datasets]
        # Force a float array: with an integer ``temperature`` the powers are
        # ints, and the in-place normalisation below would fail to cast.
        self.weights = np.array(
            [min(size, max_limit) ** temperature for size in self.lens],
            dtype=np.float64,
        )
        self.total_len = sum(self.lens)
        # Running totals; used to map a flat index onto (dataset, offset).
        self.cumulative_lens = list(accumulate(self.lens))
        # The summary shows the raw (not yet normalised) weights.
        if self.reweight:
            print_rank_0(list(zip(self.tasks, self.lens, self.weights)))
        else:
            print_rank_0(list(zip(self.tasks, self.lens)))
        self.weights /= self.weights.sum()

    def __len__(self) -> int:
        """Return 1000 times the total number of underlying samples."""
        return self.total_len * 1000

    @staticmethod
    def pet_wrapper(data):
        """Reduce a multi-candidate sample to the entry selected by its label.

        ``data`` must provide the keys ``text``, ``logit_mask``, ``target``,
        ``mask``, ``position`` and ``label``. ``target`` is always indexed by
        ``label``; when ``text`` is 2-D, ``text``, ``logit_mask``, ``mask``
        and ``position`` are indexed by ``label`` as well. A target that ends
        up 0-dimensional is repeated ``len(text)`` times.

        Returns:
            A dict with keys ``text``, ``target``, ``loss_mask``,
            ``position_id`` and ``attention_mask``.
        """
        label = data['label']
        text = data['text']
        loss_mask = data['logit_mask']
        attention_mask = data['mask']
        position_id = data['position']
        target = data['target'][label]
        if len(text.shape) == 2:
            text = text[label]
            loss_mask = loss_mask[label]
            attention_mask = attention_mask[label]
            position_id = position_id[label]
        if not target.shape:
            # Scalar target: expand it to one entry per text position.
            target = target.repeat(len(text))
        return {'text': text, 'target': target, 'loss_mask': loss_mask, 'position_id': position_id,
                'attention_mask': attention_mask}

    def __getitem__(self, idx):
        """Return the wrapped sample for ``idx``.

        In reweight mode the task and the sample are drawn from a generator
        seeded by ``idx``, so a given index always maps to the same sample.
        Otherwise ``idx`` is taken modulo the total number of samples and
        mapped onto the concatenated datasets.

        Raises:
            IndexError: In sequential mode when there are no samples at all.
        """
        if self.reweight:
            seeder = random.Random(idx)
            rng = np.random.RandomState(seed=[seeder.randint(0, 2 ** 32 - 1) for _ in range(16)])
            # ``choice(n)`` samples as if from ``np.arange(n)`` without
            # materialising an index array on every lookup.
            dataset_idx = rng.choice(len(self.datasets), p=self.weights)
            dataset = self.datasets[dataset_idx]
            sample_idx = rng.choice(len(dataset))
        else:
            if self.total_len == 0:
                raise IndexError('MultiTaskDataset contains no samples')
            # __len__ advertises more indices than there are samples; wrap
            # instead of running off the end of ``self.datasets``.
            flat_idx = idx % self.total_len
            dataset_idx = bisect_right(self.cumulative_lens, flat_idx)
            start = self.cumulative_lens[dataset_idx - 1] if dataset_idx else 0
            dataset = self.datasets[dataset_idx]
            sample_idx = flat_idx - start
        return self.pet_wrapper(dataset[sample_idx])
no outgoing calls
no test coverage detected