| 371 | ) |
| 372 | |
| 373 | def preprocess_mpt( |
| 374 | sources, |
| 375 | tokenizer: transformers.PreTrainedTokenizer, |
| 376 | ) -> Dict: |
| 377 | conv = conversation_lib.default_conversation.copy() |
| 378 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} |
| 379 | |
| 380 | # Apply prompt templates |
| 381 | conversations = [] |
| 382 | for i, source in enumerate(sources): |
| 383 | if roles[source[0]["from"]] != conv.roles[0]: |
| 384 | # Skip the first one if it is not from human |
| 385 | source = source[1:] |
| 386 | |
| 387 | conv.messages = [] |
| 388 | for j, sentence in enumerate(source): |
| 389 | role = roles[sentence["from"]] |
| 390 | assert role == conv.roles[j % 2], f"{i}" |
| 391 | conv.append_message(role, sentence["value"]) |
| 392 | conversations.append(conv.get_prompt()) |
| 393 | |
| 394 | # Tokenize conversations |
| 395 | input_ids = tokenizer( |
| 396 | conversations, |
| 397 | return_tensors="pt", |
| 398 | padding="longest", |
| 399 | max_length=tokenizer.model_max_length, |
| 400 | truncation=True, |
| 401 | ).input_ids |
| 402 | targets = input_ids.clone() |
| 403 | assert conv.sep_style == conversation_lib.SeparatorStyle.MPT |
| 404 | |
| 405 | # Mask targets |
| 406 | sep = conv.sep + conv.roles[1] |
| 407 | for conversation, target in zip(conversations, targets): |
| 408 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) |
| 409 | |
| 410 | rounds = conversation.split(conv.sep) |
| 411 | re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt |
| 412 | for conv_idx in range(3, len(rounds), 2): |
| 413 | re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt |
| 414 | cur_len = 0 |
| 415 | target[:cur_len] = IGNORE_INDEX |
| 416 | for i, rou in enumerate(re_rounds): |
| 417 | if rou == "": |
| 418 | break |
| 419 | |
| 420 | parts = rou.split(sep) |
| 421 | if len(parts) != 2: |
| 422 | break |
| 423 | parts[0] += sep |
| 424 | round_len = len(tokenizer(rou).input_ids) + len(tokenizer(conv.sep).input_ids) |
| 425 | instruction_len = len(tokenizer(parts[0]).input_ids) |
| 426 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX |
| 427 | |
| 428 | cur_len += round_len |
| 429 | target[cur_len:] = IGNORE_INDEX |
| 430 | |