(self, enable_async_checkpoint_saving)
| 437 | ("_async_checkpoint_saving", True) |
| 438 | ) |
def test_train_only(self, enable_async_checkpoint_saving):
  """Runs a train-only loop and checks which artifacts end up on disk.

  After training for 10 steps, checkpoint files and train summaries
  (including a "loss" entry) must be present under `self.model_dir`, while
  the eval summary directory must not exist.

  Args:
    enable_async_checkpoint_saving: Value forwarded to the controller's
      `enable_async_checkpointing` argument.
  """
  runner = TestRunner()

  # Locations under the model directory that the assertions below inspect.
  train_summary_dir = os.path.join(self.model_dir, "summaries/train")
  eval_summary_dir = os.path.join(self.model_dir, "summaries/eval")
  checkpoint_pattern = os.path.join(self.model_dir, "ckpt*")

  manager = tf.train.CheckpointManager(
      tf.train.Checkpoint(model=runner.model, optimizer=runner.optimizer),
      self.model_dir,
      max_to_keep=None,
      step_counter=runner.global_step,
      checkpoint_interval=10)
  train_controller = controller.Controller(
      trainer=runner,
      global_step=runner.global_step,
      steps_per_loop=2,
      summary_dir=train_summary_dir,
      checkpoint_manager=manager,
      enable_async_checkpointing=enable_async_checkpoint_saving,
      eval_summary_dir=eval_summary_dir,
  )
  train_controller.train(steps=10)

  # Checkpoints are saved.
  saved_checkpoints = tf.io.gfile.glob(checkpoint_pattern)
  self.assertNotEmpty(saved_checkpoints)

  # Only train summaries are written.
  train_summary_files = tf.io.gfile.listdir(train_summary_dir)
  self.assertNotEmpty(train_summary_files)
  loss_summaries = summaries_with_matching_keyword("loss", train_summary_dir)
  self.assertNotEmpty(loss_summaries)
  eval_dir_exists = tf.io.gfile.exists(eval_summary_dir)
  self.assertFalse(eval_dir_exists)
| 472 | |
| 473 | def test_evaluate_only(self): |
| 474 | test_runner = TestRunner() |
nothing calls this directly
no test coverage detected