(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False
)
| 412 | |
| 413 | |
| 414 | def preprocess_v1( |
| 415 | sources, |
| 416 | tokenizer: transformers.PreTrainedTokenizer, |
| 417 | has_image: bool = False |
| 418 | ) -> Dict: |
| 419 | conv = conversation_lib.default_conversation.copy() |
| 420 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} |
| 421 | |
| 422 | # Apply prompt templates |
| 423 | conversations = [] |
| 424 | for i, source in enumerate(sources): |
| 425 | if roles[source[0]["from"]] != conv.roles[0]: |
| 426 | # Skip the first one if it is not from human |
| 427 | source = source[1:] |
| 428 | |
| 429 | conv.messages = [] |
| 430 | for j, sentence in enumerate(source): |
| 431 | role = roles[sentence["from"]] |
| 432 | assert role == conv.roles[j % 2], f"{i}" |
| 433 | conv.append_message(role, sentence["value"]) |
| 434 | conversations.append(conv.get_prompt()) |
| 435 | |
| 436 | # Tokenize conversations |
| 437 | |
| 438 | if has_image: |
| 439 | input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) |
| 440 | else: |
| 441 | input_ids = tokenizer( |
| 442 | conversations, |
| 443 | return_tensors="pt", |
| 444 | padding="longest", |
| 445 | max_length=tokenizer.model_max_length, |
| 446 | truncation=True, |
| 447 | ).input_ids |
| 448 | |
| 449 | targets = input_ids.clone() |
| 450 | |
| 451 | assert conv.sep_style == conversation_lib.SeparatorStyle.TWO |
| 452 | |
| 453 | # Mask targets |
| 454 | sep = conv.sep + conv.roles[1] + ": " |
| 455 | for conversation, target in zip(conversations, targets): |
| 456 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) |
| 457 | |
| 458 | rounds = conversation.split(conv.sep2) |
| 459 | cur_len = 1 |
| 460 | target[:cur_len] = IGNORE_INDEX |
| 461 | for i, rou in enumerate(rounds): |
| 462 | if rou == "": |
| 463 | break |
| 464 | |
| 465 | parts = rou.split(sep) |
| 466 | if len(parts) != 2: |
| 467 | break |
| 468 | parts[0] += sep |
| 469 | |
| 470 | if has_image: |
| 471 | round_len = len(tokenizer_image_token(rou, tokenizer)) |
no test coverage detected