MCP · copy · Index your code
hub / github.com/XPixelGroup/DiffBIR / preprocess_mpt

Function preprocess_mpt

llava/train/train.py:500–585  ·  view source on GitHub ↗
(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
)

Source from the content-addressed store, hash-verified

498
499
500def preprocess_mpt(
501 sources,
502 tokenizer: transformers.PreTrainedTokenizer,
503 has_image: bool = False
504) -> Dict:
505 conv = conversation_lib.default_conversation.copy()
506 roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
507
508 # Apply prompt templates
509 conversations = []
510 for i, source in enumerate(sources):
511 if roles[source[0]["from"]] != conv.roles[0]:
512 # Skip the first one if it is not from human
513 source = source[1:]
514
515 conv.messages = []
516 for j, sentence in enumerate(source):
517 role = roles[sentence["from"]]
518 assert role == conv.roles[j % 2], f"{i}"
519 conv.append_message(role, sentence["value"])
520 conversations.append(conv.get_prompt())
521
522 # Tokenize conversations
523
524 if has_image:
525 input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
526 else:
527 input_ids = tokenizer(
528 conversations,
529 return_tensors="pt",
530 padding="longest",
531 max_length=tokenizer.model_max_length,
532 truncation=True,
533 ).input_ids
534
535 targets = input_ids.clone()
536 assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
537
538 # Mask targets
539 sep = conv.sep + conv.roles[1]
540 for conversation, target in zip(conversations, targets):
541 total_len = int(target.ne(tokenizer.pad_token_id).sum())
542
543 rounds = conversation.split(conv.sep)
544 re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
545 for conv_idx in range(3, len(rounds), 2):
546 re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
547 cur_len = 0
548 target[:cur_len] = IGNORE_INDEX
549 for i, rou in enumerate(re_rounds):
550 if rou == "":
551 break
552
553 parts = rou.split(sep)
554 if len(parts) != 2:
555 break
556 parts[0] += sep
557

Callers 1

preprocessFunction · 0.85

Calls 4

tokenizer_image_tokenFunction · 0.90
copyMethod · 0.80
append_messageMethod · 0.80
get_promptMethod · 0.80

Tested by

no test coverage detected