Get prediction results and write them to hard drive.
(strategy,
input_meta_data,
tokenizer,
bert_config,
squad_lib,
init_checkpoint=None)
| 388 | |
| 389 | |
def predict_squad(strategy,
                  input_meta_data,
                  tokenizer,
                  bert_config,
                  squad_lib,
                  init_checkpoint=None):
  """Gets prediction results and writes them to hard drive.

  Builds one SQuAD model and runs prediction over every file matched by
  `FLAGS.predict_file`, dumping each file's predictions via `dump_to_files`.

  Args:
    strategy: Distribution strategy, forwarded to model construction and to
      `prediction_output_squad`.
    input_meta_data: Input metadata dict, forwarded to model construction and
      prediction; its optional `version_2_with_negative` entry (default False)
      is passed to `dump_to_files`.
    tokenizer: Tokenizer forwarded to `prediction_output_squad`.
    bert_config: Model config forwarded to `get_squad_model_to_predict`.
    squad_lib: SQuAD helper library forwarded to prediction and dumping.
    init_checkpoint: Checkpoint to load. When None, the latest checkpoint in
      `FLAGS.model_dir` is used.
  """
  if init_checkpoint is None:
    init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)

  all_predict_files = _get_matched_files(FLAGS.predict_file)
  # The model is built once and reused for every matched predict file.
  squad_model = get_squad_model_to_predict(strategy, bert_config,
                                           init_checkpoint, input_meta_data)
  # A per-file output prefix is only needed when several input files are
  # processed (presumably so their dumped outputs do not collide).
  prefix_outputs = len(all_predict_files) != 1
  for predict_file in all_predict_files:
    all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
        strategy, input_meta_data, tokenizer, squad_lib, predict_file,
        squad_model)
    if prefix_outputs:
      # e.g. /path/xquad.ar.json -> file_prefix "xquad.ar-"
      file_prefix = '%s-' % os.path.splitext(
          os.path.basename(predict_file))[0]
    else:
      file_prefix = ''
    dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
                  input_meta_data.get('version_2_with_negative', False),
                  file_prefix)
| 417 | |
| 418 | |
| 419 | def eval_squad(strategy, |
nothing calls this directly
no test coverage detected