Given a list of sources, each is a conversation list. This transform: 1. Add signal '### ' at the beginning of each sentence, with end signal '\n'; 2. Concatenate conversations together; 3. Tokenize the concatenated conversation; 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
)
| 443 | |
| 444 | |
def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Turn raw conversations into token ids and training labels.

    If the default conversation template's version is "v1" or "mpt", the
    work is delegated to ``preprocess_v1`` / ``preprocess_mpt``. Otherwise:

    1. Each source is rendered into one string by
       ``_add_speaker_and_signal``, prefixed with the template's system
       prompt followed by a blank line (speaker signal at the start of each
       sentence and a newline end signal, per that helper).
    2. The rendered conversations are tokenized.
    3. The token ids are deep-copied to serve as the targets.
    4. ``_mask_targets`` is applied to every target, given the token length
       of the header and of each sentence plus the list of speakers, so the
       human-authored spans can be masked out (IGNORE_INDEX).

    Args:
        sources: The conversations. Each one is iterated as a sequence of
            mappings that carry a "from" (speaker) and a "value" (text) key.
        tokenizer: Tokenizer handed through to the tokenization helpers.

    Returns:
        A dict with ``input_ids`` (the tokenized conversations) and
        ``labels`` (the masked copy).
    """
    default_conv = conversation_lib.default_conversation
    if default_conv.version == "v1":
        return preprocess_v1(sources, tokenizer)
    if default_conv.version == "mpt":
        return preprocess_mpt(sources, tokenizer)

    # The same header is used for rendering and for the length computation.
    header = f"{default_conv.system}\n\n"

    # Render every conversation before anything is tokenized.
    # NOTE(review): if _add_speaker_and_signal modifies the sentences in
    # place, the per-sentence lengths below depend on this step having run
    # first for all sources -- keep the two phases separate; confirm.
    conversations = [
        _add_speaker_and_signal(header, source) for source in sources
    ]

    input_ids = _tokenize_fn(conversations, tokenizer)["input_ids"]
    targets = copy.deepcopy(input_ids)

    for target, source in zip(targets, sources):
        # Segment lengths (header first, then one entry per sentence) tell
        # the masking helper where each speaker's span lies.
        segments = [header]
        segments.extend(sentence["value"] for sentence in source)
        tokenized_lens = _tokenize_fn(segments, tokenizer)["input_ids_lens"]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)

    return dict(input_ids=input_ids, labels=targets)
| 477 | |
| 478 | |
| 479 | class SupervisedDataset(Dataset): |
no test coverage detected