(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False
)
| 330 | |
| 331 | |
| 332 | def preprocess_llama_2( |
| 333 | sources, |
| 334 | tokenizer: transformers.PreTrainedTokenizer, |
| 335 | has_image: bool = False |
| 336 | ) -> Dict: |
| 337 | conv = conversation_lib.default_conversation.copy() |
| 338 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} |
| 339 | |
| 340 | # Apply prompt templates |
| 341 | conversations = [] |
| 342 | for i, source in enumerate(sources): |
| 343 | if roles[source[0]["from"]] != conv.roles[0]: |
| 344 | # Skip the first one if it is not from human |
| 345 | source = source[1:] |
| 346 | |
| 347 | conv.messages = [] |
| 348 | for j, sentence in enumerate(source): |
| 349 | role = roles[sentence["from"]] |
| 350 | assert role == conv.roles[j % 2], f"{i}" |
| 351 | conv.append_message(role, sentence["value"]) |
| 352 | conversations.append(conv.get_prompt()) |
| 353 | |
| 354 | # Tokenize conversations |
| 355 | |
| 356 | if has_image: |
| 357 | input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) |
| 358 | else: |
| 359 | input_ids = tokenizer( |
| 360 | conversations, |
| 361 | return_tensors="pt", |
| 362 | padding="longest", |
| 363 | max_length=tokenizer.model_max_length, |
| 364 | truncation=True, |
| 365 | ).input_ids |
| 366 | |
| 367 | targets = input_ids.clone() |
| 368 | |
| 369 | assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 |
| 370 | |
| 371 | # Mask targets |
| 372 | sep = "[/INST] " |
| 373 | for conversation, target in zip(conversations, targets): |
| 374 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) |
| 375 | |
| 376 | rounds = conversation.split(conv.sep2) |
| 377 | cur_len = 1 |
| 378 | target[:cur_len] = IGNORE_INDEX |
| 379 | for i, rou in enumerate(rounds): |
| 380 | if rou == "": |
| 381 | break |
| 382 | |
| 383 | parts = rou.split(sep) |
| 384 | if len(parts) != 2: |
| 385 | break |
| 386 | parts[0] += sep |
| 387 | |
| 388 | if has_image: |
| 389 | round_len = len(tokenizer_image_token(rou, tokenizer)) |
no test coverage detected