()
| 114 | |
| 115 | |
| 116 | def main(): |
| 117 | args = parse_args() |
| 118 | |
| 119 | if args.model_type is not None: |
| 120 | assert args.model_type in CONFIG_TEMPLATE_ZOO, 'model_type must be in [%s]' % ( |
| 121 | ', '.join(CONFIG_TEMPLATE_ZOO.keys())) |
| 122 | print('model_type=%s, config file will be replaced by %s' % |
| 123 | (args.model_type, CONFIG_TEMPLATE_ZOO[args.model_type])) |
| 124 | args.config = CONFIG_TEMPLATE_ZOO[args.model_type] |
| 125 | |
| 126 | if args.config.startswith('http'): |
| 127 | |
| 128 | r = requests.get(args.config) |
| 129 | # download config in current dir |
| 130 | tpath = args.config.split('/')[-1] |
| 131 | while not osp.exists(tpath): |
| 132 | try: |
| 133 | with open(tpath, 'wb') as code: |
| 134 | code.write(r.content) |
| 135 | except: |
| 136 | pass |
| 137 | |
| 138 | args.config = tpath |
| 139 | |
| 140 | # build cfg |
| 141 | if args.user_config_params is None: |
| 142 | cfg = mmcv_config_fromfile(args.config) |
| 143 | else: |
| 144 | cfg = pai_config_fromfile(args.config, args.user_config_params, |
| 145 | args.model_type) |
| 146 | |
| 147 | # set multi-process settings |
| 148 | setup_multi_processes(cfg) |
| 149 | |
| 150 | # set cudnn_benchmark |
| 151 | if cfg.get('cudnn_benchmark', False): |
| 152 | torch.backends.cudnn.benchmark = True |
| 153 | |
| 154 | # update configs according to CLI args |
| 155 | if args.work_dir is not None: |
| 156 | cfg.work_dir = args.work_dir |
| 157 | |
| 158 | if cfg.get('work_dir', None) is None: |
| 159 | cfg.work_dir = './work_dir' |
| 160 | |
| 161 | # if `work_dir` is oss path, redirect `work_dir` to local path, add `oss_work_dir` point to oss path, |
| 162 | # and use osssync hook to upload log and ckpt in work_dir to oss_work_dir |
| 163 | if cfg.work_dir.startswith('oss://'): |
| 164 | cfg.oss_work_dir = cfg.work_dir |
| 165 | cfg.work_dir = osp.join('work_dirs', |
| 166 | cfg.work_dir.replace('oss://', '')) |
| 167 | else: |
| 168 | cfg.oss_work_dir = None |
| 169 | |
| 170 | if args.resume_from is not None and len(args.resume_from) > 0: |
| 171 | cfg.resume_from = args.resume_from |
| 172 | if args.load_from is not None and len(args.load_from) > 0: |
| 173 | cfg.load_from = args.load_from |
no test coverage detected