(self)
| 782 | "accuracy", os.path.join(self.model_dir, "dataset2"))) |
| 783 | |
def test_actions(self):
    """Verifies that train and eval actions see every output the controller emits.

    A single `TestRunner` serves as both trainer and evaluator. The controller
    runs `train_and_evaluate(train_steps=10, eval_steps=2, eval_interval=6)`
    with `steps_per_loop=2`. Afterwards the recorded outputs are counted, and
    each one is checked for a non-negative loss entry.
    """

    class _RecordingAction:
        """Bare-bones `Action`: keeps every output it is invoked with, in order."""

        def __init__(self):
            self.outputs = []

        def __call__(self, output):
            self.outputs.append(output)

    # One recorder per phase so train and eval outputs stay separate.
    train_recorder = _RecordingAction()
    eval_recorder = _RecordingAction()

    runner = TestRunner()
    ckpt = tf.train.Checkpoint(model=runner.model, optimizer=runner.optimizer)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt,
        self.model_dir,
        max_to_keep=None,
        step_counter=runner.global_step,
        checkpoint_interval=10,
    )

    ctrl = controller.Controller(
        trainer=runner,
        evaluator=runner,
        train_actions=[train_recorder],
        eval_actions=[eval_recorder],
        global_step=runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=ckpt_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
    )
    ctrl.train_and_evaluate(train_steps=10, eval_steps=2, eval_interval=6)

    # (recorded outputs, expected number of calls, metric key that must be
    # present and non-negative). Train is checked fully before eval, as before.
    expectations = (
        (train_recorder.outputs, 5, "loss"),
        (eval_recorder.outputs, 2, "eval_loss"),
    )
    for recorded, expected_count, metric in expectations:
        self.assertLen(recorded, expected_count)
        for logs in recorded:
            self.assertIn(metric, logs)
            self.assertGreaterEqual(logs[metric], 0)
| 830 | def test_step_per_loop_callable(self): |
| 831 | test_runner = TestRunner() |
nothing calls this directly
no test coverage detected