Method test_step_per_loop_callable

Repository: github.com/tensorflow/models
File: orbit/controller_test.py, lines 830–854
Signature: test_step_per_loop_callable(self)

Checks that orbit's Controller accepts a callable for `steps_per_loop` and that training still stops at exactly the requested global step.

# Excerpt from orbit/controller_test.py. `tf`, `controller`, and `TestRunner`
# are provided by the enclosing test module.
def test_step_per_loop_callable(self):
  test_runner = TestRunner()

  checkpoint = tf.train.Checkpoint(
      model=test_runner.model, optimizer=test_runner.optimizer)
  # Keep every checkpoint (max_to_keep=None) and require at least
  # 10 global steps between two saves.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      self.model_dir,
      max_to_keep=None,
      step_counter=test_runner.global_step,
      checkpoint_interval=10)

  # Loop length depends on the current global step: 2 steps per loop
  # up to and including step 4, then 4 steps per loop.
  def steps_per_loop_fn(global_step):
    if global_step > 4:
      return 4
    return 2

  test_controller = controller.Controller(
      trainer=test_runner,
      global_step=test_runner.global_step,
      steps_per_loop=steps_per_loop_fn,  # A callable instead of an int.
      checkpoint_manager=checkpoint_manager)
  test_controller.train(steps=10)
  # Training must stop exactly at the requested step count.
  self.assertEqual(test_runner.global_step, 10)
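The final assertion only holds if the per-loop step counts add up to exactly 10. The pure-Python sketch below walks through that arithmetic, assuming the controller evaluates the callable with the current global step before each inner loop and clamps each loop to the steps remaining. `simulate_train` is a hypothetical helper for illustration, not part of orbit's API.

```python
from typing import Callable, List, Union

# Either a fixed loop length or a function of the current global step.
StepsPerLoop = Union[int, Callable[[int], int]]


def steps_per_loop_fn(global_step: int) -> int:
  """Same schedule as the test: 2-step loops until past step 4, then 4."""
  if global_step > 4:
    return 4
  return 2


def simulate_train(steps: int, steps_per_loop: StepsPerLoop,
                   start_step: int = 0) -> List[int]:
  """Hypothetical stand-in for a controller's outer train loop.

  Re-evaluates the schedule with the current global step before every
  loop and clamps the loop so training never runs past `steps`.
  Returns the number of steps run in each loop.
  """
  global_step = start_step
  loop_sizes = []
  while global_step < steps:
    if callable(steps_per_loop):
      per_loop = steps_per_loop(global_step)
    else:
      per_loop = steps_per_loop
    if per_loop <= 0:
      raise ValueError("steps_per_loop must be positive")
    num_steps = min(per_loop, steps - global_step)
    loop_sizes.append(num_steps)
    global_step += num_steps
  return loop_sizes


loops = simulate_train(10, steps_per_loop_fn)
print(loops)       # [2, 2, 2, 4]
print(sum(loops))  # 10
```

Under these assumptions the loops run 2, 2, 2, then 4 steps: at step 4 the condition `global_step > 4` is still false, so the switch to 4-step loops only takes effect from step 6, which lands exactly on step 10.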

Callers

Nothing calls this method directly.

Calls (2)

train (method) · 0.95
TestRunner (class) · 0.85

Tested by

No test coverage detected.