MCPcopy Index your code
hub / github.com/XPixelGroup/DiffBIR / DataCollatorForSupervisedDataset

Class DataCollatorForSupervisedDataset

llava/train/train.py:743–773  ·  view source on GitHub ↗

Collate examples for supervised fine-tuning.

Source from the content-addressed store, hash-verified

741
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Pads ``input_ids`` with the tokenizer's pad token and ``labels`` with
    ``IGNORE_INDEX`` up to the longest example in the batch. Both are then
    truncated to ``tokenizer.model_max_length``, and an attention mask is
    derived. When the first instance carries an ``'image'`` key, the images
    are gathered as well.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Build one padded training batch from a sequence of example dicts.

        Args:
            instances: Non-empty sequence of dicts, each with ``'input_ids'``
                and ``'labels'`` tensors (presumably 1-D token-id tensors of
                equal length per example -- confirm against the dataset), and
                optionally an ``'image'`` entry.

        Returns:
            Dict with:
            - ``input_ids`` and ``labels``: padded to a common length along
              dim 1, then cut to ``tokenizer.model_max_length`` columns.
            - ``attention_mask``: boolean, ``True`` where ``input_ids`` differs
              from the pad token id.
            - ``images`` (only if the first instance has an ``'image'`` key): a
              stacked tensor when every image is non-None and shares one
              shape, otherwise the plain list of images.

        Raises:
            ValueError: if ``instances`` is empty, or if the tokenizer has no
                pad token id.
        """
        # Fail early with a readable message instead of an opaque error from
        # pad_sequence (empty list / padding_value=None).
        if not instances:
            raise ValueError("DataCollatorForSupervisedDataset received an empty batch")
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            raise ValueError(
                "tokenizer.pad_token_id is None; set a pad token before collating"
            )
        max_length = self.tokenizer.model_max_length

        # Pad to the longest example first, then clip to the model's context size.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [instance["input_ids"] for instance in instances],
            batch_first=True,
            padding_value=pad_token_id,
        )[:, :max_length]
        # Label padding uses IGNORE_INDEX so padded positions are excluded from the loss.
        labels = torch.nn.utils.rnn.pad_sequence(
            [instance["labels"] for instance in instances],
            batch_first=True,
            padding_value=IGNORE_INDEX,
        )[:, :max_length]

        batch = dict(
            input_ids=input_ids,
            labels=labels,
            # NOTE(review): the mask compares by value, so a real token whose id
            # equals pad_token_id (e.g. when pad reuses unk/eos) is masked too.
            attention_mask=input_ids.ne(pad_token_id),
        )

        if "image" in instances[0]:
            images = [instance["image"] for instance in instances]
            # Only stack when every image exists and shapes agree; otherwise hand
            # the raw list downstream.
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch["images"] = torch.stack(images)
            else:
                batch["images"] = images

        return batch
774
775
776def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected