MCPcopy
hub / github.com/XPixelGroup/DiffBIR / preprocess_llama_2

Function preprocess_llama_2

llava/train/train.py:332–411  ·  view source on GitHub ↗
(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
)

Source from the content-addressed store, hash-verified

330
331
332def preprocess_llama_2(
333 sources,
334 tokenizer: transformers.PreTrainedTokenizer,
335 has_image: bool = False
336) -> Dict:
337 conv = conversation_lib.default_conversation.copy()
338 roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
339
340 # Apply prompt templates
341 conversations = []
342 for i, source in enumerate(sources):
343 if roles[source[0]["from"]] != conv.roles[0]:
344 # Skip the first one if it is not from human
345 source = source[1:]
346
347 conv.messages = []
348 for j, sentence in enumerate(source):
349 role = roles[sentence["from"]]
350 assert role == conv.roles[j % 2], f"{i}"
351 conv.append_message(role, sentence["value"])
352 conversations.append(conv.get_prompt())
353
354 # Tokenize conversations
355
356 if has_image:
357 input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
358 else:
359 input_ids = tokenizer(
360 conversations,
361 return_tensors="pt",
362 padding="longest",
363 max_length=tokenizer.model_max_length,
364 truncation=True,
365 ).input_ids
366
367 targets = input_ids.clone()
368
369 assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
370
371 # Mask targets
372 sep = "[/INST] "
373 for conversation, target in zip(conversations, targets):
374 total_len = int(target.ne(tokenizer.pad_token_id).sum())
375
376 rounds = conversation.split(conv.sep2)
377 cur_len = 1
378 target[:cur_len] = IGNORE_INDEX
379 for i, rou in enumerate(rounds):
380 if rou == "":
381 break
382
383 parts = rou.split(sep)
384 if len(parts) != 2:
385 break
386 parts[0] += sep
387
388 if has_image:
389 round_len = len(tokenizer_image_token(rou, tokenizer))

Callers 1

preprocess — Function · 0.85

Calls 4

tokenizer_image_token — Function · 0.90
copy — Method · 0.80
append_message — Method · 0.80
get_prompt — Method · 0.80

Tested by

no test coverage detected