MCPcopy Index your code
hub / github.com/THUDM/GLM / MultiTaskDataset

Class MultiTaskDataset

configure_data.py:35–92  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

33
34
class MultiTaskDataset(torch.utils.data.Dataset):
    """Serve several task datasets through a single index space.

    Two access modes are supported:

    * ``reweight=True``: every index is used as a random seed. A task is
      drawn with probability proportional to
      ``min(len(dataset), max_limit) ** temperature`` and a sample is then
      drawn uniformly from that task. The same index always maps to the
      same sample.
    * ``reweight=False``: the datasets are treated as one concatenation and
      the index selects a position in it, wrapping around once the
      concatenation is exhausted.

    Every returned item is passed through :meth:`pet_wrapper`.
    """

    def __init__(self, tasks, datasets, reweight=True, temperature=0.8, max_limit=200000):
        """Record the task datasets and precompute the sampling weights.

        Args:
            tasks: Task identifiers, parallel to ``datasets`` (only used for
                the summary that is printed here).
            datasets: Sequence of indexable, sized datasets.
            reweight: Select the seeded random-sampling mode (see class
                docstring) instead of plain concatenation.
            temperature: Exponent applied to the (capped) dataset sizes.
            max_limit: Cap applied to each dataset size before weighting.
        """
        super(MultiTaskDataset, self).__init__()
        self.tasks = tasks
        self.datasets = datasets
        self.reweight = reweight
        self.temperature = temperature
        self.lens = [len(dataset) for dataset in datasets]
        # Build the array as float64 explicitly: with an integer temperature
        # the powers are Python ints, NumPy would infer an integer dtype and
        # the in-place normalisation below would raise a casting error.
        self.weights = np.array(
            [min(length, max_limit) ** temperature for length in self.lens],
            dtype=np.float64,
        )
        self.total_len = sum(self.lens)
        self.cumulative_lens = list(accumulate(self.lens))
        # The unnormalised weights are printed on purpose; normalisation
        # happens afterwards.
        if self.reweight:
            print_rank_0(list(zip(self.tasks, self.lens, self.weights)))
        else:
            print_rank_0(list(zip(self.tasks, self.lens)))
        self.weights /= self.weights.sum()

    def __len__(self):
        """Return the nominal length: 1000 times the number of real samples."""
        return self.total_len * 1000

    @staticmethod
    def pet_wrapper(data):
        """Reduce a raw sample dict to the fields of its labelled candidate.

        Reads the keys ``text``, ``logit_mask``, ``target``, ``mask``,
        ``position`` and ``label`` from ``data``.

        * If ``text`` is 2-D, entry ``label`` is selected from ``text``,
          ``logit_mask``, ``target``, ``mask`` and ``position``.
        * Otherwise only ``target`` is indexed by ``label``.

        In both cases a ``target`` that ends up with an empty shape (a
        scalar) is repeated ``len(text)`` times.

        Returns:
            dict with keys ``text``, ``target``, ``loss_mask``,
            ``position_id`` and ``attention_mask``.
        """
        text = data['text']
        loss_mask = data['logit_mask']
        target = data['target']
        attention_mask = data['mask']
        position_id = data['position']
        label = data['label']
        if len(text.shape) == 2:
            text = text[label]
            loss_mask = loss_mask[label]
            target = target[label]
            attention_mask = attention_mask[label]
            position_id = position_id[label]
        else:
            target = target[label]
        if not target.shape:
            target = target.repeat(len(text))
        return {'text': text, 'target': target, 'loss_mask': loss_mask, 'position_id': position_id,
                'attention_mask': attention_mask}

    def __getitem__(self, idx):
        """Return the wrapped sample associated with ``idx``.

        In reweight mode ``idx`` only seeds the random draw. Otherwise it is
        a position in the concatenated datasets, taken modulo the real total
        length so that every index below ``len(self)`` is valid.

        Raises:
            IndexError: in concatenation mode when there are no samples.
        """
        if self.reweight:
            # Expand the index into a 16-word seed so that each index gets
            # its own reproducible NumPy generator.
            seed_source = random.Random(idx)
            rng = np.random.RandomState(seed=[seed_source.randint(0, 2 ** 32 - 1) for _ in range(16)])
            # choice(n) is defined as choice(np.arange(n)) and consumes the
            # same random stream, without allocating the index array on
            # every call.
            dataset_idx = rng.choice(len(self.datasets), p=self.weights)
            dataset = self.datasets[dataset_idx]
            sample_idx = rng.choice(len(dataset))
            item = dataset[sample_idx]
        else:
            if self.total_len == 0:
                raise IndexError("MultiTaskDataset contains no samples")
            # __len__ reports 1000x the real size, so wrap the index back
            # into the concatenation instead of running past the last dataset.
            idx = idx % self.total_len
            dataset_idx = bisect_right(self.cumulative_lens, idx)
            if dataset_idx == 0:
                sample_idx = idx
            else:
                sample_idx = idx - self.cumulative_lens[dataset_idx - 1]
            item = self.datasets[dataset_idx][sample_idx]
        item = self.pet_wrapper(item)
        return item

Callers 1

build_multi_task_dataset · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected