MCPcopy Index your code
hub / github.com/tensorflow/models / test_early_stop_on_eval_loss

Method test_early_stop_on_eval_loss

orbit/controller_test.py:622–662  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

620 "eval_loss", os.path.join(self.model_dir, "summaries")))
621
def test_early_stop_on_eval_loss(self):
    """Checks that a custom train/eval loop can stop before `train_steps`.

    A `Controller` subclass overrides `train_and_evaluate` so that it returns
    as soon as the runner's `eval_loss` metric falls below 0.1. The test then
    asserts that the global step never reached the requested step budget.
    """
    runner = TestRunner()
    max_train_steps = 10

    class EarlyStopController(controller.Controller):
        """Controller whose train/eval loop exits once eval loss is low."""

        def train_and_evaluate(self,
                               train_steps: int = None,
                               eval_steps: int = None,
                               eval_interval: int = None):
            """Alternates training and evaluation with an early-stop check.

            Args:
              train_steps: Global step value at which the loop stops.
              eval_steps: Forwarded to `evaluate` as its `steps` argument.
              eval_interval: Upper bound on the number of steps advanced
                between two consecutive evaluations.
            """
            while self.global_step.numpy() < train_steps:
                current_step = self.global_step.numpy()
                # Never advance past `train_steps` on the last interval.
                stride = min(train_steps - current_step, eval_interval)
                # `train` is given the target global step, not a step delta.
                self.train(
                    steps=current_step + stride,
                    checkpoint_at_completion=False)
                self._sync_on_async_checkpointing()
                self.evaluate(steps=eval_steps)
                # Early stop condition.
                latest_eval_loss = runner.eval_loss.result()
                if latest_eval_loss < 0.1:
                    logging.info(
                        "Training early stopped as eval_loss %s is less than 0.1",
                        latest_eval_loss)
                    return

    ckpt = tf.train.Checkpoint(
        model=runner.model, optimizer=runner.optimizer)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt,
        self.model_dir,
        max_to_keep=None,
        step_counter=runner.global_step,
        checkpoint_interval=10)
    early_stop_controller = EarlyStopController(
        trainer=runner,
        evaluator=runner,
        global_step=runner.global_step,
        steps_per_loop=2,
        checkpoint_manager=ckpt_manager)
    early_stop_controller.train_and_evaluate(
        train_steps=max_train_steps, eval_steps=6, eval_interval=2)

    # Stopping early means the step budget was never exhausted.
    self.assertLess(runner.global_step, max_train_steps)
663
664 def test_evaluate_with_loss_output(self):
665 test_evaluator = TestEvaluator()

Callers

nothing calls this directly

Calls 3

train_and_evaluate · Method · 0.95
TestRunner · Class · 0.85
EarlyStopController · Class · 0.85

Tested by

no test coverage detected