Collate examples for supervised fine-tuning.
| 741 | |
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Right-pads ``input_ids`` with the tokenizer's pad token and ``labels``
    with ``IGNORE_INDEX``, truncates both to ``tokenizer.model_max_length``,
    and builds a boolean attention mask that is True exactly on the
    positions that came from an example (not from padding). When the first
    instance has an ``"image"`` key, images are stacked into a single tensor
    if every image is present and all share one shape; otherwise they are
    returned as a plain list.
    """

    # Supplies pad_token_id (padding value) and model_max_length (truncation).
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Build a padded batch from a sequence of tokenized examples.

        Args:
            instances: Examples, each a dict with ``"input_ids"`` and
                ``"labels"`` tensors (presumably 1-D per example — confirm
                against the dataset) and optionally an ``"image"`` entry.

        Returns:
            A dict with ``input_ids``, ``labels`` and ``attention_mask``
            (batch-first, truncated to ``tokenizer.model_max_length``), plus
            ``images`` when the first instance carries an ``"image"`` key.
        """
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        max_length = self.tokenizer.model_max_length

        # Record the real lengths *before* padding. The mask is built from
        # these rather than from `input_ids != pad_token_id`, because a pad
        # token that doubles as eos/unk would otherwise mask out genuine
        # eos/unk tokens inside the examples.
        lengths = torch.tensor([len(ids) for ids in input_ids])

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)

        # Position j of row i is real iff j < lengths[i] (right padding).
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        attention_mask = (positions.unsqueeze(0)
                          < lengths.to(input_ids.device).unsqueeze(1))

        input_ids = input_ids[:, :max_length]
        labels = labels[:, :max_length]
        attention_mask = attention_mask[:, :max_length]

        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )

        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
            # Only stack when every image exists and shapes agree; otherwise
            # hand the list through so the model can handle variable inputs.
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images

        return batch
| 774 | |
| 775 | |
| 776 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, |
no outgoing calls
no test coverage detected