Masks the user turns of a dialogue from the loss
(tokenizer, dialogue_template, labels)
| 230 | |
| 231 | |
def mask_user_labels(tokenizer, dialogue_template, labels):
    """Mask the user turns of a dialogue from the loss, in place.

    Every position from a user-role token up to (but not including) the next
    assistant-role token is overwritten with ``IGNORE_INDEX``. If a user turn
    is never followed by an assistant token, masking continues to the end of
    ``labels`` (the previous implementation raised ``IndexError`` here).
    Assistant tokens, and everything after them until the next user token,
    are left untouched.

    Args:
        tokenizer: Object exposing ``convert_tokens_to_ids``; used to look up
            the ids of the template's role tokens.
        dialogue_template: Object exposing ``user_token`` and
            ``assistant_token`` attributes.
        labels: Mutable, indexable sequence of token ids. Modified in place.

    Returns:
        None. ``labels`` is mutated in place.
    """
    user_token_id = tokenizer.convert_tokens_to_ids(dialogue_template.user_token)
    assistant_token_id = tokenizer.convert_tokens_to_ids(dialogue_template.assistant_token)
    in_user_turn = False
    for idx, label_id in enumerate(labels):
        # Check the assistant token first so it always ends a user turn and is
        # never itself masked (same stop condition as the original scan).
        if label_id == assistant_token_id:
            in_user_turn = False
        elif label_id == user_token_id:
            in_user_turn = True
        if in_user_turn:
            # Only the current position is overwritten, after it has been read,
            # so the ongoing enumeration is unaffected.
            labels[idx] = IGNORE_INDEX