(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False
)
| 498 | |
| 499 | |
| 500 | def preprocess_mpt( |
| 501 | sources, |
| 502 | tokenizer: transformers.PreTrainedTokenizer, |
| 503 | has_image: bool = False |
| 504 | ) -> Dict: |
| 505 | conv = conversation_lib.default_conversation.copy() |
| 506 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} |
| 507 | |
| 508 | # Apply prompt templates |
| 509 | conversations = [] |
| 510 | for i, source in enumerate(sources): |
| 511 | if roles[source[0]["from"]] != conv.roles[0]: |
| 512 | # Skip the first one if it is not from human |
| 513 | source = source[1:] |
| 514 | |
| 515 | conv.messages = [] |
| 516 | for j, sentence in enumerate(source): |
| 517 | role = roles[sentence["from"]] |
| 518 | assert role == conv.roles[j % 2], f"{i}" |
| 519 | conv.append_message(role, sentence["value"]) |
| 520 | conversations.append(conv.get_prompt()) |
| 521 | |
| 522 | # Tokenize conversations |
| 523 | |
| 524 | if has_image: |
| 525 | input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) |
| 526 | else: |
| 527 | input_ids = tokenizer( |
| 528 | conversations, |
| 529 | return_tensors="pt", |
| 530 | padding="longest", |
| 531 | max_length=tokenizer.model_max_length, |
| 532 | truncation=True, |
| 533 | ).input_ids |
| 534 | |
| 535 | targets = input_ids.clone() |
| 536 | assert conv.sep_style == conversation_lib.SeparatorStyle.MPT |
| 537 | |
| 538 | # Mask targets |
| 539 | sep = conv.sep + conv.roles[1] |
| 540 | for conversation, target in zip(conversations, targets): |
| 541 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) |
| 542 | |
| 543 | rounds = conversation.split(conv.sep) |
| 544 | re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt |
| 545 | for conv_idx in range(3, len(rounds), 2): |
| 546 | re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt |
| 547 | cur_len = 0 |
| 548 | target[:cur_len] = IGNORE_INDEX |
| 549 | for i, rou in enumerate(re_rounds): |
| 550 | if rou == "": |
| 551 | break |
| 552 | |
| 553 | parts = rou.split(sep) |
| 554 | if len(parts) != 2: |
| 555 | break |
| 556 | parts[0] += sep |
| 557 |
No test coverage detected.