A subclass of Controller that supports early stopping.
| 623 | test_runner = TestRunner() |
| 624 | |
class EarlyStopController(controller.Controller):
    """A subclass of Controller that supports early stopping.

    Training and evaluation alternate until `train_steps` is reached, but
    the loop returns early as soon as the evaluation loss reported by the
    surrounding `test_runner` falls below 0.1.
    """

    def train_and_evaluate(self,
                           train_steps: int = None,
                           eval_steps: int = None,
                           eval_interval: int = None):
        """Alternates training and evaluation, stopping early on low loss.

        Args:
          train_steps: Global step at which training ends. Required in
            practice; the `None` default only mirrors the signature and is
            rejected below.
          eval_steps: Passed through unchanged as `steps` to
            `self.evaluate`.
          eval_interval: Maximum number of training steps between two
            evaluations. When `None`, all remaining steps are trained
            before a single evaluation.

        Raises:
          TypeError: If `train_steps` is `None`.
        """
        if train_steps is None:
            # Fail with a clear message instead of an opaque comparison error
            # from `global_step < None`.
            raise TypeError("`train_steps` must be an int, got None.")

        # Read the step counter once per iteration rather than on every use.
        current_step = self.global_step.numpy()
        while current_step < train_steps:
            remaining = train_steps - current_step
            # `min(remaining, None)` would raise, so an unset interval means
            # "train everything that is left, then evaluate".
            if eval_interval is None:
                interval = remaining
            else:
                interval = min(remaining, eval_interval)

            self.train(
                steps=current_step + interval, checkpoint_at_completion=False)
            self._sync_on_async_checkpointing()
            self.evaluate(steps=eval_steps)

            # Early stop condition.
            eval_loss = test_runner.eval_loss.result()
            if eval_loss < 0.1:
                logging.info(
                    "Training early stopped as eval_loss %s is less than 0.1",
                    eval_loss)
                return

            current_step = self.global_step.numpy()
| 644 | |
| 645 | checkpoint = tf.train.Checkpoint( |
| 646 | model=test_runner.model, optimizer=test_runner.optimizer) |
no outgoing calls