(
doc_tokens, features, start_logits, end_logits, n_best_size, max_answer_length
)
| 300 | |
| 301 | |
| 302 | def get_predictions( |
| 303 | doc_tokens, features, start_logits, end_logits, n_best_size, max_answer_length |
| 304 | ): |
| 305 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name |
| 306 | "PrelimPrediction", |
| 307 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"], |
| 308 | ) |
| 309 | |
| 310 | prediction = "" |
| 311 | scores_diff_json = 0.0 |
| 312 | |
| 313 | prelim_predictions = [] |
| 314 | # keep track of the minimum score of null start+end of position 0 |
| 315 | score_null = 1000000 # large and positive |
| 316 | min_null_feature_index = 0 # the paragraph slice with min null score |
| 317 | null_start_logit = 0 # the start logit at the slice with min null score |
| 318 | null_end_logit = 0 # the end logit at the slice with min null score |
| 319 | |
| 320 | start_indexes = _get_best_indexes(start_logits, n_best_size) |
| 321 | end_indexes = _get_best_indexes(end_logits, n_best_size) |
| 322 | |
| 323 | # if we could have irrelevant answers, get the min score of irrelevant |
| 324 | version_2_with_negative = True |
| 325 | if version_2_with_negative: |
| 326 | feature_null_score = start_logits[0] + end_logits[0] |
| 327 | if feature_null_score < score_null: |
| 328 | score_null = feature_null_score |
| 329 | min_null_feature_index = 0 |
| 330 | null_start_logit = start_logits[0] |
| 331 | null_end_logit = end_logits[0] |
| 332 | |
| 333 | for start_index in start_indexes: |
| 334 | for end_index in end_indexes: |
| 335 | # We could hypothetically create invalid predictions, e.g., predict |
| 336 | # that the start of the span is in the question. We throw out all |
| 337 | # invalid predictions. |
| 338 | if start_index >= len(features["tokens"]): |
| 339 | continue |
| 340 | if end_index >= len(features["tokens"]): |
| 341 | continue |
| 342 | if start_index not in features["token_to_orig_map"]: |
| 343 | continue |
| 344 | if end_index not in features["token_to_orig_map"]: |
| 345 | continue |
| 346 | if not features["token_is_max_context"].get(start_index, False): |
| 347 | continue |
| 348 | if end_index < start_index: |
| 349 | continue |
| 350 | length = end_index - start_index + 1 |
| 351 | if length > max_answer_length: |
| 352 | continue |
| 353 | prelim_predictions.append( |
| 354 | _PrelimPrediction( |
| 355 | feature_index=0, |
| 356 | start_index=start_index, |
| 357 | end_index=end_index, |
| 358 | start_logit=start_logits[start_index], |
| 359 | end_logit=end_logits[end_index], |
nothing calls this directly
no test coverage detected