Merge CLI arguments to config.
(cfg, args)
| 83 | |
| 84 | |
| 85 | def merge_args(cfg, args): |
| 86 | """Merge CLI arguments to config.""" |
| 87 | cfg.launcher = args.launcher |
| 88 | |
| 89 | # work_dir is determined in this priority: CLI > segment in file > filename |
| 90 | if args.work_dir is not None: |
| 91 | # update configs according to CLI args if args.work_dir is not None |
| 92 | cfg.work_dir = args.work_dir |
| 93 | elif cfg.get('work_dir', None) is None: |
| 94 | # use config filename as default work_dir if cfg.work_dir is None |
| 95 | cfg.work_dir = osp.join('./work_dirs', |
| 96 | osp.splitext(osp.basename(args.config))[0]) |
| 97 | |
| 98 | cfg.load_from = args.checkpoint |
| 99 | |
| 100 | # enable automatic-mixed-precision test |
| 101 | if args.amp: |
| 102 | cfg.test_cfg.fp16 = True |
| 103 | |
| 104 | # -------------------- visualization -------------------- |
| 105 | if args.show or (args.show_dir is not None): |
| 106 | assert 'visualization' in cfg.default_hooks, \ |
| 107 | 'VisualizationHook is not set in the `default_hooks` field of ' \ |
| 108 | 'config. Please set `visualization=dict(type="VisualizationHook")`' |
| 109 | |
| 110 | cfg.default_hooks.visualization.enable = True |
| 111 | cfg.default_hooks.visualization.show = args.show |
| 112 | cfg.default_hooks.visualization.wait_time = args.wait_time |
| 113 | cfg.default_hooks.visualization.out_dir = args.show_dir |
| 114 | cfg.default_hooks.visualization.interval = args.interval |
| 115 | |
| 116 | # -------------------- TTA related args -------------------- |
| 117 | if args.tta: |
| 118 | if 'tta_model' not in cfg: |
| 119 | cfg.tta_model = dict(type='mmpretrain.AverageClsScoreTTA') |
| 120 | if 'tta_pipeline' not in cfg: |
| 121 | test_pipeline = cfg.test_dataloader.dataset.pipeline |
| 122 | cfg.tta_pipeline = deepcopy(test_pipeline) |
| 123 | flip_tta = dict( |
| 124 | type='TestTimeAug', |
| 125 | transforms=[ |
| 126 | [ |
| 127 | dict(type='RandomFlip', prob=1.), |
| 128 | dict(type='RandomFlip', prob=0.) |
| 129 | ], |
| 130 | [test_pipeline[-1]], |
| 131 | ]) |
| 132 | cfg.tta_pipeline[-1] = flip_tta |
| 133 | cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model) |
| 134 | cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline |
| 135 | |
| 136 | # ----------------- Default dataloader args ----------------- |
| 137 | default_dataloader_cfg = ConfigDict( |
| 138 | pin_memory=True, |
| 139 | collate_fn=dict(type='default_collate'), |
| 140 | ) |
| 141 | |
| 142 | def set_default_dataloader_cfg(cfg, field): |
no test coverage detected