(self)
| 620 | "eval_loss", os.path.join(self.model_dir, "summaries"))) |
| 621 | |
def test_early_stop_on_eval_loss(self):
  """Checks that a Controller subclass can stop training on low eval loss.

  A local `EarlyStopController` overrides `train_and_evaluate` so that it
  returns as soon as the runner's `eval_loss` metric falls below 0.1. The
  test requests 10 training steps and asserts that the global step is still
  below 10 afterwards.
  """
  test_runner = TestRunner()

  class EarlyStopController(controller.Controller):
    """A subclass of Controller that supports early stopping."""

    def train_and_evaluate(self,
                           train_steps: int = None,
                           eval_steps: int = None,
                           eval_interval: int = None):
      """Alternates train/evaluate rounds, returning early on low eval loss.

      Args:
        train_steps: Global-step value at which the loop stops.
        eval_steps: Forwarded to `self.evaluate` as `steps`.
        eval_interval: Upper bound on how far each round advances the
          step target passed to `self.train`.
      """
      while self.global_step.numpy() < train_steps:
        current_step = self.global_step.numpy()
        steps_remaining = train_steps - current_step
        # Never aim past `train_steps`, even when the interval would.
        target_step = current_step + min(steps_remaining, eval_interval)
        self.train(steps=target_step, checkpoint_at_completion=False)
        self._sync_on_async_checkpointing()
        self.evaluate(steps=eval_steps)
        # Early stop condition.
        eval_loss = test_runner.eval_loss.result()
        if not eval_loss < 0.1:
          continue
        logging.info(
            "Training early stopped as eval_loss %s is less than 0.1",
            eval_loss)
        return

  # Checkpoint the runner's model and optimizer, keyed on its global step.
  ckpt = tf.train.Checkpoint(
      model=test_runner.model, optimizer=test_runner.optimizer)
  ckpt_manager = tf.train.CheckpointManager(
      ckpt,
      self.model_dir,
      max_to_keep=None,
      step_counter=test_runner.global_step,
      checkpoint_interval=10)

  # The same runner object serves as both trainer and evaluator.
  early_stop_controller = EarlyStopController(
      trainer=test_runner,
      evaluator=test_runner,
      global_step=test_runner.global_step,
      steps_per_loop=2,
      checkpoint_manager=ckpt_manager)
  early_stop_controller.train_and_evaluate(
      train_steps=10, eval_steps=6, eval_interval=2)

  # Stopping early means the full 10 steps were never reached.
  self.assertLess(test_runner.global_step, 10)
| 663 | |
| 664 | def test_evaluate_with_loss_output(self): |
| 665 | test_evaluator = TestEvaluator() |
nothing calls this directly
no test coverage detected