(
sources,
tokenizer: transformers.PreTrainedTokenizer,
data_args,
)
| 98 | |
| 99 | |
| 100 | def preprocess( |
| 101 | sources, |
| 102 | tokenizer: transformers.PreTrainedTokenizer, |
| 103 | data_args, |
| 104 | ) -> Dict: |
| 105 | conv = get_conversation_template("yuan2") # wpf |
| 106 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} |
| 107 | |
| 108 | # Apply prompt templates |
| 109 | conversations = [] |
| 110 | for i, source in enumerate(sources): |
| 111 | if roles[source[0]["from"]] != conv.roles[0]: |
| 112 | # Skip the first one if it is not from human |
| 113 | source = source[1:] |
| 114 | |
| 115 | conv.messages = [] |
| 116 | for j, sentence in enumerate(source): |
| 117 | role = roles[sentence["from"]] |
| 118 | assert role == conv.roles[j % 2], f"{i}" |
| 119 | conv.append_message(role, sentence["value"]) |
| 120 | conversations.append(conv.get_prompt()) |
| 121 | if data_args.last_response_loss: |
| 122 | a = conversations[0].replace("<sep>", "<eod>") |
| 123 | a = right_replace(a, "<n>", "<sep>") |
| 124 | # a=right_replace(a,"<n>","\n",max=20) |
| 125 | conversations[0] = a |
| 126 | if data_args.split_example_loss: |
| 127 | a = conversations[0].replace("<sep>", "") |
| 128 | a = a.split("<n>") |
| 129 | for i in range(int(len(a) / 2)): |
| 130 | if i == 0: |
| 131 | conversations[i] = "" |
| 132 | if i != 0: |
| 133 | conversations.append("") |
| 134 | for j in range(i * 2): |
| 135 | conversations[i] = conversations[i] + a[j] + "<n>" |
| 136 | conversations[i] = ( |
| 137 | conversations[i] + a[i * 2] + "<sep>" + a[i * 2 + 1] + "<eod>" |
| 138 | ) |
| 139 | |
| 140 | if data_args.efficient_loss: |
| 141 | a = conversations[0].replace("<sep>", "<eod>") |
| 142 | conversations[0] = a |
| 143 | |
| 144 | print(conversations) |
| 145 | |
| 146 | # Tokenize conversations |
| 147 | input_ids = tokenizer( |
| 148 | conversations, |
| 149 | return_tensors="pt", |
| 150 | padding="max_length", |
| 151 | max_length=tokenizer.model_max_length, |
| 152 | truncation=True, |
| 153 | ).input_ids |
| 154 | targets = input_ids.clone() |
| 155 | |
| 156 | # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO #wpf |
| 157 | # Mask targets. Only compute loss on the assistant outputs. |
no test coverage detected
searching dependent graphs…