(self)
| 828 | self.assertGreaterEqual(output["eval_loss"], 0) |
| 829 | |
def test_step_per_loop_callable(self):
  """Trains with `steps_per_loop` given as a callable instead of an int.

  The callable returns 2 while the global step is at most 4 and 4 once it
  exceeds 4. After `train(steps=10)` the runner's global step must be
  exactly 10.
  """
  runner = TestRunner()

  ckpt = tf.train.Checkpoint(model=runner.model, optimizer=runner.optimizer)
  # max_to_keep=None keeps every checkpoint; the manager reads the same
  # step variable that the controller advances.
  ckpt_manager = tf.train.CheckpointManager(
      ckpt,
      self.model_dir,
      max_to_keep=None,
      step_counter=runner.global_step,
      checkpoint_interval=10)

  # The parameter keeps its original name (`global_step`) in case the
  # controller passes it by keyword.
  def loop_size_for(global_step):
    # Smaller value early on, larger value once past step 4.
    return 4 if global_step > 4 else 2

  ctrl = controller.Controller(
      trainer=runner,
      global_step=runner.global_step,
      steps_per_loop=loop_size_for,
      checkpoint_manager=ckpt_manager)
  ctrl.train(steps=10)

  self.assertEqual(runner.global_step, 10)
| 855 | |
| 856 | |
| 857 | if __name__ == "__main__": |
nothing calls this directly
no test coverage detected