(example, prediction, tokenizer)
| 216 | |
| 217 | |
def squad_decode(example, prediction, tokenizer):
    """Map a predicted token-id span back to answer text.

    If the decoded prediction reads "n/a" (ignoring spaces and case), the
    decoded text is returned unchanged. Otherwise the first place where
    ``prediction`` appears as a contiguous run inside
    ``example.meta['context_tokens']`` is located. The matching character
    span of ``example.meta['context']`` is returned, using the per-token
    ``(start, end)`` character offsets in ``example.meta['token_to_char']``.
    If there is no such run, the decoded text is passed through
    ``squad_fix_tokenization``.

    Args:
        example: object whose ``meta`` mapping holds ``'context'``,
            ``'context_tokens'`` and ``'token_to_char'``.
        prediction: sequence of token ids. It is compared with ``==``
            against slices of ``context_tokens``, so it presumably needs
            the same sequence type (e.g. both lists) -- TODO confirm with
            the caller.
        tokenizer: object providing ``DecodeIds(ids) -> str``.

    Returns:
        The answer string.
    """
    text = tokenizer.DecodeIds(prediction)
    # A decoded "n/a" (presumably the no-answer marker) is returned as-is.
    if text.replace(' ', '').lower() == 'n/a':
        return text
    context = example.meta['context']
    context_tokens = example.meta['context_tokens']
    token_to_char = example.meta['token_to_char']
    n = len(prediction)
    # An empty prediction would match the empty slice at position 0. Its end
    # offset would then be read from token_to_char[-1], returning the whole
    # context, so only search when there is something to match.
    if n > 0:
        # Only start positions where a full-length window fits can match.
        for i in range(len(context_tokens) - n + 1):
            if prediction == context_tokens[i:i + n]:
                start = token_to_char[i][0]
                end = token_to_char[i + n - 1][1]
                return context[start:end]
    return squad_fix_tokenization(text)
| 232 | |
| 233 | |
| 234 | def process_batch(batch, args): |
no test coverage detected