(
sources,
tokenizer: transformers.PreTrainedTokenizer,
)
| 90 | |
| 91 | |
| 92 | def preprocess( |
| 93 | sources, |
| 94 | tokenizer: transformers.PreTrainedTokenizer, |
| 95 | ) -> Dict: |
| 96 | conv = get_conversation_template("vicuna") |
| 97 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} |
| 98 | |
| 99 | # Apply prompt templates |
| 100 | conversations = [] |
| 101 | for i, source in enumerate(sources): |
| 102 | if roles[source[0]["from"]] != conv.roles[0]: |
| 103 | # Skip the first one if it is not from human |
| 104 | source = source[1:] |
| 105 | |
| 106 | conv.messages = [] |
| 107 | for j, sentence in enumerate(source): |
| 108 | role = roles[sentence["from"]] |
| 109 | assert role == conv.roles[j % 2], f"{i}" |
| 110 | conv.append_message(role, sentence["value"]) |
| 111 | conversations.append(conv.get_prompt()) |
| 112 | |
| 113 | # Tokenize conversations |
| 114 | input_ids = tokenizer( |
| 115 | conversations, |
| 116 | return_tensors="pt", |
| 117 | padding="max_length", |
| 118 | max_length=tokenizer.model_max_length, |
| 119 | truncation=True, |
| 120 | ).input_ids |
| 121 | targets = input_ids.clone() |
| 122 | |
| 123 | assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO |
| 124 | |
| 125 | # Mask targets. Only compute loss on the assistant outputs. |
| 126 | sep = conv.sep + conv.roles[1] + ": " |
| 127 | for conversation, target in zip(conversations, targets): |
| 128 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) |
| 129 | |
| 130 | turns = conversation.split(conv.sep2) |
| 131 | cur_len = 1 |
| 132 | target[:cur_len] = IGNORE_TOKEN_ID |
| 133 | for i, turn in enumerate(turns): |
| 134 | if turn == "": |
| 135 | break |
| 136 | turn_len = len(tokenizer(turn).input_ids) |
| 137 | |
| 138 | parts = turn.split(sep) |
| 139 | if len(parts) != 2: |
| 140 | break |
| 141 | parts[0] += sep |
| 142 | # "-2" is hardcoded for the Llama tokenizer to make the offset correct. |
| 143 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 |
| 144 | |
| 145 | if i != 0 and not tokenizer.legacy: |
| 146 | # The legacy and non-legacy modes handle special tokens differently |
| 147 | instruction_len -= 1 |
| 148 | |
| 149 | # Ignore the user instructions |
no test coverage detected
searching dependent graphs…