MCP · Copy · Index your code
hub / github.com/tensorflow/models / test_train_only

Method test_train_only

orbit/controller_test.py:439–471  ·  view source on GitHub ↗
(self, enable_async_checkpoint_saving)

Source from the content-addressed store, hash-verified

437 ("_async_checkpoint_saving", True)
438 )
def test_train_only(self, enable_async_checkpoint_saving):
  """Runs training alone and checks its checkpoint and summary side effects.

  The controller is built with a trainer but no evaluator and trained for 10
  steps. The model directory is then inspected:
  - checkpoint files matching "ckpt*" must exist;
  - the train summary directory must be non-empty and contain a "loss" summary;
  - the eval summary directory must not exist.

  Args:
    enable_async_checkpoint_saving: Forwarded unchanged to the controller's
      `enable_async_checkpointing` argument.
  """
  runner = TestRunner()
  train_summary_dir = os.path.join(self.model_dir, "summaries/train")
  eval_summary_dir = os.path.join(self.model_dir, "summaries/eval")

  ckpt = tf.train.Checkpoint(model=runner.model, optimizer=runner.optimizer)
  ckpt_manager = tf.train.CheckpointManager(
      ckpt,
      self.model_dir,
      max_to_keep=None,
      step_counter=runner.global_step,
      checkpoint_interval=10,
  )

  # No evaluator is passed, yet an eval summary dir is configured. The final
  # assertion checks that this directory is never created during training.
  trainer_only_controller = controller.Controller(
      trainer=runner,
      global_step=runner.global_step,
      steps_per_loop=2,
      summary_dir=train_summary_dir,
      checkpoint_manager=ckpt_manager,
      enable_async_checkpointing=enable_async_checkpoint_saving,
      eval_summary_dir=eval_summary_dir,
  )
  trainer_only_controller.train(steps=10)

  # Checkpoints are saved.
  ckpt_files = tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*"))
  self.assertNotEmpty(ckpt_files)

  # Only train summaries are written.
  self.assertNotEmpty(tf.io.gfile.listdir(train_summary_dir))
  loss_summaries = summaries_with_matching_keyword("loss", train_summary_dir)
  self.assertNotEmpty(loss_summaries)
  self.assertFalse(tf.io.gfile.exists(eval_summary_dir))
472
473 def test_evaluate_only(self):
474 test_runner = TestRunner()

Callers

nothing calls this directly

Calls 4

train · Method · 0.95
TestRunner · Class · 0.85
join · Method · 0.45

Tested by

no test coverage detected