MCP · copy
hub / github.com/open-mmlab/mmpretrain / merge_args

Function merge_args

tools/test.py:85–156  ·  view source on GitHub ↗

Merge CLI arguments to config.

(cfg, args)

Source from the content-addressed store, hash-verified

83
84
85def merge_args(cfg, args):
86 """Merge CLI arguments to config."""
87 cfg.launcher = args.launcher
88
89 # work_dir is determined in this priority: CLI > segment in file > filename
90 if args.work_dir is not None:
91 # update configs according to CLI args if args.work_dir is not None
92 cfg.work_dir = args.work_dir
93 elif cfg.get('work_dir', None) is None:
94 # use config filename as default work_dir if cfg.work_dir is None
95 cfg.work_dir = osp.join('./work_dirs',
96 osp.splitext(osp.basename(args.config))[0])
97
98 cfg.load_from = args.checkpoint
99
100 # enable automatic-mixed-precision test
101 if args.amp:
102 cfg.test_cfg.fp16 = True
103
104 # -------------------- visualization --------------------
105 if args.show or (args.show_dir is not None):
106 assert 'visualization' in cfg.default_hooks, \
107 'VisualizationHook is not set in the `default_hooks` field of ' \
108 'config. Please set `visualization=dict(type="VisualizationHook")`'
109
110 cfg.default_hooks.visualization.enable = True
111 cfg.default_hooks.visualization.show = args.show
112 cfg.default_hooks.visualization.wait_time = args.wait_time
113 cfg.default_hooks.visualization.out_dir = args.show_dir
114 cfg.default_hooks.visualization.interval = args.interval
115
116 # -------------------- TTA related args --------------------
117 if args.tta:
118 if 'tta_model' not in cfg:
119 cfg.tta_model = dict(type='mmpretrain.AverageClsScoreTTA')
120 if 'tta_pipeline' not in cfg:
121 test_pipeline = cfg.test_dataloader.dataset.pipeline
122 cfg.tta_pipeline = deepcopy(test_pipeline)
123 flip_tta = dict(
124 type='TestTimeAug',
125 transforms=[
126 [
127 dict(type='RandomFlip', prob=1.),
128 dict(type='RandomFlip', prob=0.)
129 ],
130 [test_pipeline[-1]],
131 ])
132 cfg.tta_pipeline[-1] = flip_tta
133 cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
134 cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
135
136 # ----------------- Default dataloader args -----------------
137 default_dataloader_cfg = ConfigDict(
138 pin_memory=True,
139 collate_fn=dict(type='default_collate'),
140 )
141
142 def set_default_dataloader_cfg(cfg, field):

Callers 1

main (Function) · 0.70

Calls 2

get (Method) · 0.80

Tested by

no test coverage detected