MCPcopy Index your code
hub / github.com/tensorflow/models / test_actions

Method test_actions

orbit/controller_test.py:784–828  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

782 "accuracy", os.path.join(self.model_dir, "dataset2")))
783
def test_actions(self):
  """Checks that train and eval actions are handed the loop outputs.

  A `Controller` is built with one recording action for training and one
  for evaluation, then `train_and_evaluate` is run for 10 train steps with
  2 steps per loop and an eval interval of 6. The test expects 5 recorded
  train outputs and 2 recorded eval outputs, each carrying a non-negative
  loss value under its respective key.
  """

  class OutputRecorderAction:
    """Callable that keeps every output it is invoked with, in call order."""

    def __init__(self):
      self.outputs = []

    def __call__(self, output):
      self.outputs.append(output)

  runner = TestRunner()
  ckpt = tf.train.Checkpoint(model=runner.model, optimizer=runner.optimizer)
  manager = tf.train.CheckpointManager(
      ckpt,
      self.model_dir,
      max_to_keep=None,
      step_counter=runner.global_step,
      checkpoint_interval=10)

  train_recorder = OutputRecorderAction()
  eval_recorder = OutputRecorderAction()

  # Train and eval summaries go to separate sub-directories of model_dir.
  train_summary_dir = os.path.join(self.model_dir, "summaries/train")
  eval_summary_dir = os.path.join(self.model_dir, "summaries/eval")

  # The same runner object serves as both trainer and evaluator.
  ctrl = controller.Controller(
      trainer=runner,
      evaluator=runner,
      train_actions=[train_recorder],
      eval_actions=[eval_recorder],
      global_step=runner.global_step,
      steps_per_loop=2,
      summary_dir=train_summary_dir,
      checkpoint_manager=manager,
      eval_summary_dir=eval_summary_dir)
  ctrl.train_and_evaluate(train_steps=10, eval_steps=2, eval_interval=6)

  # (recorder, expected number of outputs, key that must be present).
  # NOTE(review): 5 presumably corresponds to 10 steps / 2 steps per loop,
  # and 2 to evaluations at step 6 and at the end of training -- confirm
  # against the Controller implementation.
  expectations = (
      (train_recorder, 5, "loss"),
      (eval_recorder, 2, "eval_loss"),
  )
  for recorder, expected_count, loss_key in expectations:
    self.assertLen(recorder.outputs, expected_count)
    for recorded in recorder.outputs:
      self.assertIn(loss_key, recorded)
      self.assertGreaterEqual(recorded[loss_key], 0)
829
830 def test_step_per_loop_callable(self):
831 test_runner = TestRunner()

Callers

nothing calls this directly

Calls 4

train_and_evaluateMethod · 0.95
TestRunnerClass · 0.85
joinMethod · 0.45

Tested by

no test coverage detected