Run classification or regression. Args: custom_callbacks: list of tf_keras.Callbacks passed to training loop. custom_metrics: list of metrics passed to the training loop.
(custom_callbacks=None, custom_metrics=None)
| 418 | |
| 419 | |
| 420 | def custom_main(custom_callbacks=None, custom_metrics=None): |
| 421 | """Run classification or regression. |
| 422 | |
| 423 | Args: |
| 424 | custom_callbacks: list of tf_keras.Callbacks passed to training loop. |
| 425 | custom_metrics: list of metrics passed to the training loop. |
| 426 | """ |
| 427 | gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param) |
| 428 | |
| 429 | with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader: |
| 430 | input_meta_data = json.loads(reader.read().decode('utf-8')) |
| 431 | label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')] |
| 432 | include_sample_weights = input_meta_data.get('has_sample_weights', False) |
| 433 | |
| 434 | if not FLAGS.model_dir: |
| 435 | FLAGS.model_dir = '/tmp/bert20/' |
| 436 | |
| 437 | bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file) |
| 438 | |
| 439 | if FLAGS.mode == 'export_only': |
| 440 | export_classifier(FLAGS.model_export_path, input_meta_data, bert_config, |
| 441 | FLAGS.model_dir) |
| 442 | return |
| 443 | |
| 444 | strategy = distribute_utils.get_distribution_strategy( |
| 445 | distribution_strategy=FLAGS.distribution_strategy, |
| 446 | num_gpus=FLAGS.num_gpus, |
| 447 | tpu_address=FLAGS.tpu) |
| 448 | eval_input_fn = get_dataset_fn( |
| 449 | FLAGS.eval_data_path, |
| 450 | input_meta_data['max_seq_length'], |
| 451 | FLAGS.eval_batch_size, |
| 452 | is_training=False, |
| 453 | label_type=label_type, |
| 454 | include_sample_weights=include_sample_weights) |
| 455 | |
| 456 | if FLAGS.mode == 'predict': |
| 457 | num_labels = input_meta_data.get('num_labels', 1) |
| 458 | with strategy.scope(): |
| 459 | classifier_model = bert_models.classifier_model( |
| 460 | bert_config, num_labels)[0] |
| 461 | checkpoint = tf.train.Checkpoint(model=classifier_model) |
| 462 | latest_checkpoint_file = ( |
| 463 | FLAGS.predict_checkpoint_path or |
| 464 | tf.train.latest_checkpoint(FLAGS.model_dir)) |
| 465 | assert latest_checkpoint_file |
| 466 | logging.info('Checkpoint file %s found and restoring from ' |
| 467 | 'checkpoint', latest_checkpoint_file) |
| 468 | checkpoint.restore( |
| 469 | latest_checkpoint_file).assert_existing_objects_matched() |
| 470 | preds, _ = get_predictions_and_labels( |
| 471 | strategy, |
| 472 | classifier_model, |
| 473 | eval_input_fn, |
| 474 | is_regression=(num_labels == 1), |
| 475 | return_probs=True) |
| 476 | output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv') |
| 477 | with tf.io.gfile.GFile(output_predict_file, 'w') as writer: |
no test coverage detected