(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
)
| 586 | |
| 587 | |
def preprocess_plain(
    sources: Sequence[Sequence[Dict[str, str]]],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize two-turn "plain" (image + caption) conversations.

    Each source must hold exactly two turns. The first turn's ``value`` must
    contain ``DEFAULT_IMAGE_TOKEN``; it is replaced by that token alone, so any
    other text in the first turn is dropped. The second turn's ``value`` is
    appended verbatim, followed by ``conversation_lib.default_conversation.sep``.

    Labels are a deep copy of ``input_ids`` in which the leading tokens produced
    by the image-token prefix are overwritten with ``IGNORE_INDEX``
    (presumably excluded from the loss; confirm against the trainer).

    The input ``sources`` are not modified.

    Args:
        sources: Sequence of conversations. Each conversation is a sequence of
            two turn mappings that each carry a ``'value'`` key.
        tokenizer: Tokenizer forwarded to ``tokenizer_image_token``.

    Returns:
        ``dict(input_ids=[...], labels=[...])`` with one entry per source.

    Raises:
        ValueError: if a source does not have exactly two turns, or if its
            first turn does not contain ``DEFAULT_IMAGE_TOKEN``.
    """
    sep = conversation_lib.default_conversation.sep

    # Validate explicitly rather than with `assert`, which `python -O` strips.
    conversations = []
    for index, source in enumerate(sources):
        if len(source) != 2:
            raise ValueError(
                f"plain conversation #{index} must have exactly 2 turns, "
                f"got {len(source)}"
            )
        if DEFAULT_IMAGE_TOKEN not in source[0]['value']:
            raise ValueError(
                f"first turn of plain conversation #{index} must contain "
                f"{DEFAULT_IMAGE_TOKEN!r}"
            )
        # The first turn is reduced to the bare image token; the caption
        # (second turn) and the separator follow it.
        conversations.append(DEFAULT_IMAGE_TOKEN + source[1]['value'] + sep)

    # Tokenize the full prompts.
    input_ids = [
        tokenizer_image_token(prompt, tokenizer, return_tensors='pt')
        for prompt in conversations
    ]
    targets = copy.deepcopy(input_ids)

    if targets:
        # Every prompt starts with the same image-token prefix, so its
        # tokenized length is computed once rather than per sample.
        prefix_len = len(tokenizer_image_token(DEFAULT_IMAGE_TOKEN, tokenizer))
        for target in targets:
            target[:prefix_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)
| 608 | |
| 609 | |
| 610 | def preprocess( |
no test coverage detected