(self, builder)
| 66 | |
| 67 | class TestCheckpoint(TestCase): |
def run_with(self, builder):
    """Exercise checkpointed training using sessions supplied by ``builder``.

    ``builder`` is called with no arguments and its result is unpacked as a
    ``(session, checkpoint)`` pair. The check has three phases:

    1. Train from scratch and verify the epoch count and the final total.
    2. Resume from every epoch in turn and verify the final total again.
    3. Load every epoch's checkpoint and verify that epoch's total.

    All expected values come from ``EXPECTED_TOTALS``.
    """
    with Cluster():
        with Job() as job:
            outputs = build_pipeline(node_id=0)
        # NOTE(review): the fetch task is assumed to sit outside the Job
        # context but inside the Cluster -- confirm against the original
        # indentation.
        fetch_task = Task(step=core.Net('empty'), outputs=outputs)

        def read_total(sess):
            # Run the fetch task in the given session, then read back its
            # first output.
            sess.run(fetch_task)
            first_output = fetch_task.outputs()[0]
            return first_output.fetch()

        # Phase 1: one full training run from scratch.
        session, checkpoint = builder()
        job.compile(LocalSession)
        runner = JobRunner(job, checkpoint)
        num_epochs = runner.train(session)
        self.assertEqual(num_epochs, len(EXPECTED_TOTALS))
        self.assertEqual(read_total(session), EXPECTED_TOTALS[-1])

        # Phase 2: resuming from any epoch must still reach the final total.
        for start_epoch in range(1, num_epochs + 1):
            session, checkpoint = builder()
            resumed_runner = JobRunner(
                job, checkpoint, resume_from_epoch=start_epoch)
            resumed_runner.train(session)
            self.assertEqual(read_total(session), EXPECTED_TOTALS[-1])

        # Phase 3: reuses the session/checkpoint left over from the last
        # resume above; each saved epoch must restore that epoch's total.
        for epoch in range(1, num_epochs + 1):
            session.run(checkpoint.load(epoch))
            restored_total = read_total(session)
            self.assertEqual(restored_total, EXPECTED_TOTALS[epoch - 1])
| 96 | |
| 97 | def test_single_checkpoint(self): |
| 98 | # test single node |
no test coverage detected