(self)
| 24 | print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) |
| 25 | |
def test_export_hook(self):
    """Run a short training loop with an ExportHook registered and check outputs.

    Builds a small model with an SGD optimizer, registers text/tensorboard
    logger hooks, a checkpoint hook (interval=1) and an ExportHook, then runs
    the runner. Afterwards it asserts that the work dir contains the
    epoch-1 and epoch-2 checkpoints and the epoch-2 exported model.
    """
    model = _build_model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
    log_config = dict(
        interval=1,
        hooks=[
            dict(type='TextLoggerHook'),
            dict(type='TensorboardLoggerHook'),
        ])
    checkpoint_config = dict(interval=1)

    work_dir = get_tmp_dir()
    # Register removal right away so the temp dir is deleted even when the
    # run raises or an assertion below fails (previously it leaked).
    self.addCleanup(io.rmtree, work_dir)

    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = os.path.join(work_dir, '{}.log'.format(timestamp))
    logger = get_root_logger(log_file=log_file)
    runner = EVRunner(
        model=model, work_dir=work_dir, optimizer=optimizer, logger=logger)

    runner.register_logger_hooks(log_config)
    runner.register_checkpoint_hook(checkpoint_config)

    cfg = Config(cfg_dict=dict(model=dict(type='TestExportModel')))
    # The export hook reads its output location from the config.
    cfg.work_dir = work_dir
    hook = ExportHook(cfg)
    loader = DataLoader(torch.ones((3, 2)))

    runner.register_hook(hook)
    # NOTE(review): presumably workflow [('train', 1)] with max_epochs=2,
    # which matches the two per-epoch checkpoints asserted below.
    runner.run([loader], [('train', 1)], 2)

    files_list = io.listdir(work_dir)
    self.assertIn('epoch_2_export.pt', files_list)
    self.assertIn('epoch_1.pth', files_list)
    self.assertIn('epoch_2.pth', files_list)
| 60 | |
| 61 | |
| 62 | def _build_model(): |
nothing calls this directly
no test coverage detected