Given a list of sources, each of which is a conversation list, this transform: 1. adds the signal '### ' at the beginning of each sentence, with the end signal '\n'; 2. concatenates the conversations together; 3. tokenizes the concatenated conversation; 4. makes a deep copy as the target and masks human words with IGNORE_INDEX.
(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False
)
| 608 | |
| 609 | |
def preprocess(
    sources: Sequence[Sequence[Dict]],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Build ``input_ids`` / ``labels`` from raw conversations.

    Conversation templates that have a dedicated preprocessor are dispatched
    to it: ``SeparatorStyle.PLAIN``, ``SeparatorStyle.LLAMA_2``, versions
    starting with ``"v1"``, and version ``"mpt"``. Otherwise this fallback:

    1. Adds the signal '### ' at the beginning of each sentence, with the end
       signal '\\n' (via ``_add_speaker_and_signal``);
    2. Concatenates each conversation behind the system-prompt header;
    3. Tokenizes the concatenated conversations;
    4. Deep-copies the token ids as the targets and masks human words with
       IGNORE_INDEX (via ``_mask_targets``).

    Args:
        sources: One entry per sample. Each entry is a sequence of turn dicts
            that carry at least a ``"from"`` (speaker) and a ``"value"``
            (text) key.
        tokenizer: Tokenizer used to encode the conversations.
        has_image: If True, prompts are encoded with ``tokenizer_image_token``
            (``return_tensors='pt'``); otherwise with ``_tokenize_fn``.

    Returns:
        ``dict(input_ids=..., labels=...)``, where ``labels`` is a deep copy
        of ``input_ids`` that has been passed through ``_mask_targets``.
    """
    conv = conversation_lib.default_conversation
    if conv.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
        return preprocess_llama_2(sources, tokenizer, has_image=has_image)
    if conv.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image)
    if conv.version == "mpt":
        return preprocess_mpt(sources, tokenizer, has_image=has_image)

    # The header is the same for every sample. Build it once up front rather
    # than reading a variable that leaks out of the concatenation loop.
    header = f"{conv.system}\n\n"

    # Add the speaker/end signals and concatenate each conversation.
    conversations = [_add_speaker_and_signal(header, source) for source in sources]

    # Tokenize the full conversations.
    if has_image:
        input_ids = [
            tokenizer_image_token(prompt, tokenizer, return_tensors='pt')
            for prompt in conversations
        ]
    else:
        input_ids = _tokenize_fn(conversations, tokenizer)["input_ids"]

    # NOTE(review): the per-turn lengths below are measured on turn["value"]
    # *after* _add_speaker_and_signal has run over every source. If that
    # helper rewrites the turn text in place (so the lengths include the
    # speaker tags), this statement order is load-bearing. Confirm before
    # reordering.
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        pieces = [header] + [turn["value"] for turn in source]
        if has_image:
            tokenized_lens = [len(tokenizer_image_token(piece, tokenizer)) for piece in pieces]
        else:
            tokenized_lens = _tokenize_fn(pieces, tokenizer)["input_ids_lens"]
        speakers = [turn["from"] for turn in source]
        # _mask_targets updates `target` in place; its return value is unused.
        _mask_targets(target, tokenized_lens, speakers)

    return dict(input_ids=input_ids, labels=targets)
| 656 | |
| 657 | |
| 658 | class LazySupervisedDataset(Dataset): |
no test coverage detected