MCPcopy
hub / github.com/XPixelGroup/DiffBIR / preprocess_v1

Function preprocess_v1

llava/train/train.py:414–497  ·  view source on GitHub ↗
(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
)

Source from the content-addressed store, hash-verified

412
413
414def preprocess_v1(
415 sources,
416 tokenizer: transformers.PreTrainedTokenizer,
417 has_image: bool = False
418) -> Dict:
419 conv = conversation_lib.default_conversation.copy()
420 roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
421
422 # Apply prompt templates
423 conversations = []
424 for i, source in enumerate(sources):
425 if roles[source[0]["from"]] != conv.roles[0]:
426 # Skip the first one if it is not from human
427 source = source[1:]
428
429 conv.messages = []
430 for j, sentence in enumerate(source):
431 role = roles[sentence["from"]]
432 assert role == conv.roles[j % 2], f"{i}"
433 conv.append_message(role, sentence["value"])
434 conversations.append(conv.get_prompt())
435
436 # Tokenize conversations
437
438 if has_image:
439 input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
440 else:
441 input_ids = tokenizer(
442 conversations,
443 return_tensors="pt",
444 padding="longest",
445 max_length=tokenizer.model_max_length,
446 truncation=True,
447 ).input_ids
448
449 targets = input_ids.clone()
450
451 assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
452
453 # Mask targets
454 sep = conv.sep + conv.roles[1] + ": "
455 for conversation, target in zip(conversations, targets):
456 total_len = int(target.ne(tokenizer.pad_token_id).sum())
457
458 rounds = conversation.split(conv.sep2)
459 cur_len = 1
460 target[:cur_len] = IGNORE_INDEX
461 for i, rou in enumerate(rounds):
462 if rou == "":
463 break
464
465 parts = rou.split(sep)
466 if len(parts) != 2:
467 break
468 parts[0] += sep
469
470 if has_image:
471 round_len = len(tokenizer_image_token(rou, tokenizer))

Callers 1

preprocess · Function · 0.85

Calls 4

tokenizer_image_token · Function · 0.90
copy · Method · 0.80
append_message · Method · 0.80
get_prompt · Method · 0.80

Tested by

no test coverage detected