Project the tokenized prediction back to the original text.
Signature: `get_final_text(pred_text, orig_text, do_lower_case)`
| 191 | |
| 192 | |
| 193 | def get_final_text(pred_text, orig_text, do_lower_case): |
| 194 | """Project the tokenized prediction back to the original text.""" |
| 195 | |
| 196 | # When we created the data, we kept track of the alignment between original |
| 197 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So |
| 198 | # now `orig_text` contains the span of our original text corresponding to the |
| 199 | # span that we predicted. |
| 200 | # |
| 201 | # However, `orig_text` may contain extra characters that we don't want in |
| 202 | # our prediction. |
| 203 | # |
| 204 | # For example, let's say: |
| 205 | # pred_text = steve smith |
| 206 | # orig_text = Steve Smith's |
| 207 | # |
| 208 | # We don't want to return `orig_text` because it contains the extra "'s". |
| 209 | # |
| 210 | # We don't want to return `pred_text` because it's already been normalized |
| 211 | # (the SQuAD eval script also does punctuation stripping/lower casing but |
| 212 | # our tokenizer does additional normalization like stripping accent |
| 213 | # characters). |
| 214 | # |
| 215 | # What we really want to return is "Steve Smith". |
| 216 | # |
| 217 | # Therefore, we have to apply a semi-complicated alignment heuristic between |
| 218 | # `pred_text` and `orig_text` to get a character-to-character alignment. This |
| 219 | # can fail in certain cases in which case we just return `orig_text`. |
| 220 | |
| 221 | def _strip_spaces(text): |
| 222 | ns_chars = [] |
| 223 | ns_to_s_map = collections.OrderedDict() |
| 224 | for i, c in enumerate(text): |
| 225 | if c == " ": |
| 226 | continue |
| 227 | ns_to_s_map[len(ns_chars)] = i |
| 228 | ns_chars.append(c) |
| 229 | ns_text = "".join(ns_chars) |
| 230 | return (ns_text, ns_to_s_map) |
| 231 | |
| 232 | # We first tokenize `orig_text`, strip whitespace from the result |
| 233 | # and `pred_text`, and check if they are the same length. If they are |
| 234 | # NOT the same length, the heuristic has failed. If they are the same |
| 235 | # length, we assume the characters are one-to-one aligned. |
| 236 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) |
| 237 | |
| 238 | tok_text = " ".join(tokenizer.tokenize(orig_text)) |
| 239 | |
| 240 | start_position = tok_text.find(pred_text) |
| 241 | if start_position == -1: |
| 242 | return orig_text |
| 243 | end_position = start_position + len(pred_text) - 1 |
| 244 | |
| 245 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) |
| 246 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) |
| 247 | |
| 248 | if len(orig_ns_text) != len(tok_ns_text): |
| 249 | return orig_text |
| 250 |
no test coverage detected