MCPcopy
hub / github.com/tensorflow/models / custom_main

Function custom_main

official/legacy/bert/run_classifier.py:420–503  ·  view source on GitHub ↗

Run classification or regression. Args: custom_callbacks: list of tf_keras callbacks passed to the training loop. custom_metrics: list of metrics passed to the training loop.

(custom_callbacks=None, custom_metrics=None)

Source from the content-addressed store, hash-verified

418
419
420def custom_main(custom_callbacks=None, custom_metrics=None):
421 """Run classification or regression.
422
423 Args:
424 custom_callbacks: list of tf_keras.Callbacks passed to training loop.
425 custom_metrics: list of metrics passed to the training loop.
426 """
427 gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
428
429 with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
430 input_meta_data = json.loads(reader.read().decode('utf-8'))
431 label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
432 include_sample_weights = input_meta_data.get('has_sample_weights', False)
433
434 if not FLAGS.model_dir:
435 FLAGS.model_dir = '/tmp/bert20/'
436
437 bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
438
439 if FLAGS.mode == 'export_only':
440 export_classifier(FLAGS.model_export_path, input_meta_data, bert_config,
441 FLAGS.model_dir)
442 return
443
444 strategy = distribute_utils.get_distribution_strategy(
445 distribution_strategy=FLAGS.distribution_strategy,
446 num_gpus=FLAGS.num_gpus,
447 tpu_address=FLAGS.tpu)
448 eval_input_fn = get_dataset_fn(
449 FLAGS.eval_data_path,
450 input_meta_data['max_seq_length'],
451 FLAGS.eval_batch_size,
452 is_training=False,
453 label_type=label_type,
454 include_sample_weights=include_sample_weights)
455
456 if FLAGS.mode == 'predict':
457 num_labels = input_meta_data.get('num_labels', 1)
458 with strategy.scope():
459 classifier_model = bert_models.classifier_model(
460 bert_config, num_labels)[0]
461 checkpoint = tf.train.Checkpoint(model=classifier_model)
462 latest_checkpoint_file = (
463 FLAGS.predict_checkpoint_path or
464 tf.train.latest_checkpoint(FLAGS.model_dir))
465 assert latest_checkpoint_file
466 logging.info('Checkpoint file %s found and restoring from '
467 'checkpoint', latest_checkpoint_file)
468 checkpoint.restore(
469 latest_checkpoint_file).assert_existing_objects_matched()
470 preds, _ = get_predictions_and_labels(
471 strategy,
472 classifier_model,
473 eval_input_fn,
474 is_regression=(num_labels == 1),
475 return_probs=True)
476 output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
477 with tf.io.gfile.GFile(output_predict_file, 'w') as writer:

Callers 1

main · Function · 0.85

Calls 11

export_classifier · Function · 0.85
run_bert · Function · 0.85
info · Method · 0.80
write · Method · 0.80
get_dataset_fn · Function · 0.70
decode · Method · 0.45
read · Method · 0.45
get · Method · 0.45
from_json_file · Method · 0.45
join · Method · 0.45

Tested by

no test coverage detected