Repository: github.com/alibaba/EasyCV

Method test_export_hook

tests/test_hooks/test_export_hook.py:26–59
(self)


def test_export_hook(self):
    model = _build_model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
    # Log every iteration to both the text log and TensorBoard.
    log_config = dict(
        interval=1,
        hooks=[
            dict(type='TextLoggerHook'),
            dict(type='TensorboardLoggerHook'),
        ])
    # Save a checkpoint after every epoch.
    checkpoint_config = dict(interval=1)

    work_dir = get_tmp_dir()
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = os.path.join(work_dir, '{}.log'.format(timestamp))
    logger = get_root_logger(log_file=log_file)
    runner = EVRunner(
        model=model, work_dir=work_dir, optimizer=optimizer, logger=logger)

    runner.register_logger_hooks(log_config)
    runner.register_checkpoint_hook(checkpoint_config)

    # ExportHook is built from a config naming the model type and work_dir.
    cfg = Config(cfg_dict=dict(model=dict(type='TestExportModel')))
    cfg.work_dir = work_dir
    hook = ExportHook(cfg)
    # A 3x2 tensor of ones yields three single-row batches per epoch.
    loader = DataLoader(torch.ones((3, 2)))

    runner.register_hook(hook)
    # Train for two epochs using a single 'train' workflow stage.
    runner.run([loader], [('train', 1)], 2)

    # Expect one checkpoint per epoch plus an export of the final epoch.
    files_list = io.listdir(work_dir)
    self.assertIn('epoch_2_export.pt', files_list)
    self.assertIn('epoch_1.pth', files_list)
    self.assertIn('epoch_2.pth', files_list)

    io.rmtree(work_dir)
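To make the file-naming contract this test checks more concrete, here is a minimal pure-Python sketch of the hook lifecycle it relies on. MiniRunner, Hook, CheckpointSketchHook, ExportSketchHook and run_sketch are hypothetical names, not EasyCV or mmcv APIs; only the callback names before_run, after_train_epoch and after_run mirror mmcv-style hooks. It assumes, as the assertions suggest, that checkpoints are named epoch_{n}.pth and that the export is written once, after the last epoch, as epoch_{n}_export.pt.

```python
import os
import tempfile


class Hook:
    """Hypothetical base hook; callback names mirror mmcv-style hooks."""

    def before_run(self, runner):
        pass

    def after_train_epoch(self, runner):
        pass

    def after_run(self, runner):
        pass


class CheckpointSketchHook(Hook):
    """Writes an empty epoch_{n}.pth file every `interval` epochs."""

    def __init__(self, interval=1):
        self.interval = interval

    def after_train_epoch(self, runner):
        if runner.epoch % self.interval == 0:
            name = 'epoch_{}.pth'.format(runner.epoch)
            open(os.path.join(runner.work_dir, name), 'wb').close()


class ExportSketchHook(Hook):
    """Assumed behavior: export once, after the final epoch."""

    def after_run(self, runner):
        name = 'epoch_{}_export.pt'.format(runner.epoch)
        open(os.path.join(runner.work_dir, name), 'wb').close()


class MiniRunner:
    """Hypothetical epoch-based runner (not EasyCV's EVRunner)."""

    def __init__(self, work_dir):
        self.work_dir = work_dir
        self.epoch = 0  # counts completed epochs, starting from 1 after the first
        self.hooks = []

    def register_hook(self, hook):
        self.hooks.append(hook)

    def call_hook(self, name):
        # Invoke the named callback on every registered hook, in order.
        for hook in self.hooks:
            getattr(hook, name)(self)

    def run(self, max_epochs):
        self.call_hook('before_run')
        while self.epoch < max_epochs:
            # A real runner would iterate over the data loader here.
            self.epoch += 1
            self.call_hook('after_train_epoch')
        self.call_hook('after_run')


def run_sketch(max_epochs=2, interval=1):
    """Run the sketch in a temporary dir and return the sorted file names."""
    with tempfile.TemporaryDirectory() as work_dir:
        runner = MiniRunner(work_dir)
        runner.register_hook(CheckpointSketchHook(interval))
        runner.register_hook(ExportSketchHook())
        runner.run(max_epochs)
        return sorted(os.listdir(work_dir))


if __name__ == '__main__':
    print(run_sketch())  # ['epoch_1.pth', 'epoch_2.pth', 'epoch_2_export.pt']
```

With the test's settings (two epochs, checkpoint interval 1) the sketch produces exactly the three files the test asserts on; calling run_sketch(max_epochs=4, interval=2) instead shows how the expected names shift when the schedule changes.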

Callers

Nothing calls this directly; as a unittest test method, it is invoked by the test runner.

Calls (8)

get_tmp_dir (function) · 0.90
get_root_logger (function) · 0.90
EVRunner (class) · 0.90
ExportHook (class) · 0.90
Config (class) · 0.85
_build_model (function) · 0.70
listdir (method) · 0.45
rmtree (method) · 0.45

Tested by

no test coverage detected