| 480 | """Dataset for supervised fine-tuning.""" |
| 481 | |
def __init__(self, data_path: str,
             tokenizer: transformers.PreTrainedTokenizer):
    """Load a JSON dataset from disk and preprocess it eagerly.

    Args:
        data_path: Path to a JSON file. The decoded top-level value is
            iterated, and each item is indexed with "conversations".
        tokenizer: Tokenizer passed through unchanged to ``preprocess``.

    Raises:
        OSError: If ``data_path`` cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
        KeyError: If an example (assumed to be a dict -- confirm against
            the data format) has no "conversations" entry.
    """
    super(SupervisedDataset, self).__init__()
    logging.warning("Loading data...")
    # Use a context manager so the file handle is closed deterministically,
    # even when JSON decoding fails part-way through.
    with open(data_path, "r") as data_file:
        list_data_dict = json.load(data_file)

    logging.warning("Formatting inputs...")
    sources = [example["conversations"] for example in list_data_dict]
    # `preprocess` is defined elsewhere in this file; the only thing relied
    # on here is that its result exposes "input_ids" and "labels".
    data_dict = preprocess(sources, tokenizer)

    self.input_ids = data_dict["input_ids"]
    self.labels = data_dict["labels"]
| 494 | |
| 495 | def __len__(self): |
| 496 | return len(self.input_ids) |