| 35 | return prompt_content |
| 36 | |
class PretrainDataset(Dataset):
    """Map-style dataset producing fixed-length token sequences for pretraining.

    Each record is read from a JSON dataset and must expose a ``'text'``
    field. The text is tokenized, wrapped in BOS/EOS tokens and right-padded
    to ``max_length``.
    """

    def __init__(self, data_path, tokenizer, max_length=512):
        """Load the samples and remember the tokenization settings.

        Args:
            data_path: Passed as ``data_files`` to ``load_dataset('json', ...)``.
            tokenizer: Callable returning an object with ``input_ids``; must
                also expose ``bos_token_id``, ``eos_token_id`` and
                ``pad_token_id``.
            max_length: Length of every returned sequence, including the
                BOS and EOS tokens.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = load_dataset('json', data_files=data_path, split='train')

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(input_ids, labels)`` long tensors for the sample at ``index``.

        ``labels`` is a copy of ``input_ids`` in which the padding positions
        are set to -100 (the default ``ignore_index`` of PyTorch's
        cross-entropy loss).
        """
        sample = self.samples[index]
        # Reserve two slots for the BOS/EOS tokens added below.
        encoded = self.tokenizer(
            str(sample['text']),
            add_special_tokens=False,
            max_length=self.max_length - 2,
            truncation=True,
        )
        tokens = [self.tokenizer.bos_token_id] + encoded.input_ids + [self.tokenizer.eos_token_id]
        content_length = len(tokens)

        padding = [self.tokenizer.pad_token_id] * (self.max_length - content_length)
        input_ids = torch.tensor(tokens + padding, dtype=torch.long)

        labels = input_ids.clone()
        # Mask by position, not by token value: when the tokenizer shares one
        # id between pad and EOS (or the pad id occurs inside the text), a
        # value-based mask would also hide real tokens from the loss.
        labels[content_length:] = -100
        return input_ids, labels
| 56 | |
| 57 | |
| 58 | class SFTDataset(Dataset): |