github.com/alibaba/EasyCV / _base_train

Method _base_train

tests/test_tools/test_pose_train.py:66–95
(self, train_cfgs)

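        super().tearDown()

    def _base_train(self, train_cfgs):
        # Set up OSS access so the io helpers can also handle oss:// paths.
        io.access_oss()

        # Pop the keys consumed here; whatever remains in train_cfgs is
        # forwarded to tools/train.py as key=value arguments.
        cfg_file = train_cfgs.pop('config_file')
        cfg_options = train_cfgs.pop('cfg_options', None)
        work_dir = train_cfgs.pop('work_dir', None)
        if not work_dir:
            # Only the unique path is kept: on CPython the directory is
            # deleted once the TemporaryDirectory object is garbage-collected,
            # so tools/train.py is expected to create it again.
            work_dir = tempfile.TemporaryDirectory().name

        # Load the config and apply any per-test overrides.
        cfg = Config.fromfile(cfg_file)
        if cfg_options is not None:
            cfg.merge_from_dict(cfg_options)

        # Evaluate on the validation dataset with one image per GPU.
        cfg.eval_pipelines[0].data = dict(**cfg.data.val, imgs_per_gpu=1)

        # Dump the modified config to a fresh temporary .py path
        # (again, only the file name of the temporary object is reused).
        tmp_cfg_file = tempfile.NamedTemporaryFile(suffix='.py').name
        cfg.dump(tmp_cfg_file)

        # Remaining train_cfgs entries become "key=value" arguments.
        args_str = ' '.join(
            ['='.join((str(k), str(v))) for k, v in train_cfgs.items()])
        cmd = 'python tools/train.py %s --work_dir=%s %s --fp16' % \
            (tmp_cfg_file, work_dir, args_str)

        run_in_subprocess(cmd)

        # A successful run is expected to leave the first-epoch checkpoint.
        output_files = io.listdir(work_dir)
        self.assertIn('epoch_1.pth', output_files)

        # Clean up the work directory and the temporary config file.
        io.remove(work_dir)
        io.remove(tmp_cfg_file)

    # def test_litehrnet(self):
    #     train_cfgs = copy.deepcopy(TRAIN_CONFIGS[0])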
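The argument handling in `_base_train` can be sketched in isolation. `build_train_cmd` below is a hypothetical helper, not part of EasyCV: it consumes `config_file`, `cfg_options`, and `work_dir` the same way and formats the remaining entries with the same `key=value` join and command template. One difference: `_base_train` pops from the dict it receives (which is why callers pass `copy.deepcopy(TRAIN_CONFIGS[0])`), while this sketch works on a shallow copy. The `--seed` key in the usage lines is only illustrative; check `tools/train.py` for the flags it actually accepts.

```python
from typing import Any, Dict, Optional, Tuple


def build_train_cmd(train_cfgs: Dict[str, Any],
                    cfg_path: str,
                    default_work_dir: str) -> Tuple[str, Optional[dict], str]:
    """Hypothetical helper mirroring how _base_train splits its argument dict."""
    cfgs = dict(train_cfgs)  # shallow copy: the caller's dict is left intact
    cfgs.pop('config_file')  # the test loads this path with Config.fromfile
    cfg_options = cfgs.pop('cfg_options', None)
    work_dir = cfgs.pop('work_dir', None) or default_work_dir

    # Every remaining entry is forwarded verbatim as a "key=value" token.
    args_str = ' '.join('='.join((str(k), str(v))) for k, v in cfgs.items())
    cmd = 'python tools/train.py %s --work_dir=%s %s --fp16' % (
        cfg_path, work_dir, args_str)
    return cmd, cfg_options, work_dir


if __name__ == '__main__':
    cmd, _, _ = build_train_cmd(
        {'config_file': 'configs/example.py', '--seed': 0},
        '/tmp/cfg.py', '/tmp/work')
    print(cmd)
    # prints: python tools/train.py /tmp/cfg.py --work_dir=/tmp/work --seed=0 --fp16
```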

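The `eval_pipelines` line builds the evaluation dataset config by unpacking `cfg.data.val` and adding `imgs_per_gpu=1`. The sketch below uses plain dicts as stand-ins for the config objects (an assumption about how they unpack) to show a subtlety of the `dict(**val, key=...)` form: it raises `TypeError` if `val` already defines `imgs_per_gpu`, whereas `dict(val, key=...)` simply overrides the existing value. Both function names are hypothetical.

```python
from typing import Any, Dict


def eval_data_cfg(val_cfg: Dict[str, Any]) -> Dict[str, Any]:
    # Same expression as the test: copy the validation config and add
    # imgs_per_gpu=1. Raises TypeError if val_cfg already has that key.
    return dict(**val_cfg, imgs_per_gpu=1)


def eval_data_cfg_override(val_cfg: Dict[str, Any]) -> Dict[str, Any]:
    # Variant that overwrites an existing imgs_per_gpu instead of raising.
    return dict(val_cfg, imgs_per_gpu=1)
```

Both forms return a new dict, so the original validation config is never modified; they differ only when the key is already present.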
Callers (2)

test_litehrnet_oss · method · 0.95
test_hrnet · method · 0.95

Calls (5)

run_in_subprocess · function · 0.90
access_oss · method · 0.80
dump · method · 0.80
listdir · method · 0.45
remove · method · 0.45
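To run the same check-and-clean-up steps outside EasyCV, `run_in_subprocess`, `listdir`, and `remove` from the calls listed above can be approximated with the standard library. These are hypothetical, local-only stand-ins, not EasyCV's implementations: the real `io.listdir` and `io.remove` also handle `oss://` paths, and this sketch assumes `run_in_subprocess` only needs to run a shell command and fail on a non-zero exit status.

```python
import os
import shutil
import subprocess


def run_in_subprocess_local(cmd: str) -> None:
    # Hypothetical stand-in for run_in_subprocess: run a shell command and
    # raise subprocess.CalledProcessError if it exits with a non-zero status.
    subprocess.run(cmd, shell=True, check=True)


def check_and_cleanup(work_dir: str, cfg_file: str,
                      expected: str = 'epoch_1.pth') -> bool:
    # Local-only stand-ins for io.listdir / io.remove (no oss:// support).
    found = expected in os.listdir(work_dir)
    shutil.rmtree(work_dir, ignore_errors=True)
    if os.path.exists(cfg_file):
        os.remove(cfg_file)
    return found
```

Unlike the test, `check_and_cleanup` returns a boolean and leaves the assertion to the caller, so cleanup happens even when the checkpoint is missing.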

Tested by

no test coverage detected