| 36 | |
| 37 | |
class PromptDataset(Dataset):
    """Map-style dataset that exposes tokenized samples with ``labels`` added.

    Wraps an indexable collection whose items provide ``"input_ids"`` and
    ``"attention_mask"`` entries, and returns each item as a dict that
    additionally carries ``"labels"`` set to the item's ``"input_ids"``.
    """

    def __init__(self, chosen_dataset) -> None:
        """Store the wrapped collection.

        Args:
            chosen_dataset: Object supporting ``len()`` and integer indexing;
                each item must support lookup of ``"input_ids"`` and
                ``"attention_mask"``.
        """
        super().__init__()
        self.dataset = chosen_dataset

    def __len__(self) -> int:
        """Return the number of samples in the wrapped collection."""
        return len(self.dataset)

    def __getitem__(self, idx) -> dict:
        """Return the sample at ``idx`` with ``labels`` mirroring ``input_ids``.

        Args:
            idx: Index forwarded unchanged to the wrapped collection.

        Returns:
            A dict with keys ``"input_ids"``, ``"attention_mask"`` and
            ``"labels"``. ``"labels"`` refers to the same object as
            ``"input_ids"`` (it is not a copy).
        """
        # Fetch the row once: indexing the wrapped collection may be costly
        # (e.g. lazily materialised rows), and all fields come from one row.
        sample = self.dataset[idx]
        input_ids = sample["input_ids"]
        return {
            "input_ids": input_ids,
            "attention_mask": sample["attention_mask"],
            "labels": input_ids,
        }
| 54 | |
| 55 | def get_weight_data(current_dataset, dataset_weight): |
| 56 | dataset = [] |