()
| 73 | |
| 74 | |
| 75 | def main(): |
| 76 | args = parse_args() |
| 77 | |
| 78 | if args.model_type is not None and args.config is None: |
| 79 | assert args.model_type in CONFIG_TEMPLATE_ZOO, 'model_type must be in [%s]' % ( |
| 80 | ', '.join(CONFIG_TEMPLATE_ZOO.keys())) |
| 81 | print('model_type=%s, config file will be replaced by %s' % |
| 82 | (args.model_type, CONFIG_TEMPLATE_ZOO[args.model_type])) |
| 83 | args.config = CONFIG_TEMPLATE_ZOO[args.model_type] |
| 84 | |
| 85 | if args.config.startswith('http'): |
| 86 | |
| 87 | r = requests.get(args.config) |
| 88 | # download config in current dir |
| 89 | tpath = args.config.split('/')[-1] |
| 90 | while not osp.exists(tpath): |
| 91 | try: |
| 92 | with open(tpath, 'wb') as code: |
| 93 | code.write(r.content) |
| 94 | except: |
| 95 | pass |
| 96 | |
| 97 | args.config = tpath |
| 98 | |
| 99 | cfg = mmcv_config_fromfile(args.config) |
| 100 | |
| 101 | if args.user_config_params is not None: |
| 102 | assert args.model_type is not None, 'model_type must be setted' |
| 103 | # rebuild config by user config params |
| 104 | cfg = rebuild_config(cfg, args.user_config_params) |
| 105 | |
| 106 | # check oss_config and init oss io |
| 107 | if cfg.get('oss_io_config', None) is not None: |
| 108 | io.access_oss(**cfg.oss_io_config) |
| 109 | |
| 110 | # set cudnn_benchmark |
| 111 | if cfg.get('cudnn_benchmark', False): |
| 112 | torch.backends.cudnn.benchmark = True |
| 113 | |
| 114 | # update configs according to CLI args |
| 115 | if args.work_dir is not None: |
| 116 | cfg.work_dir = args.work_dir |
| 117 | |
| 118 | # if `work_dir` is an OSS path, redirect `work_dir` to a local path, keep the OSS path in `oss_work_dir`, |
| 119 | # and use the osssync hook to upload logs and checkpoints from work_dir to oss_work_dir |
| 120 | if cfg.work_dir.startswith('oss://'): |
| 121 | cfg.oss_work_dir = cfg.work_dir |
| 122 | cfg.work_dir = osp.join('work_dirs', |
| 123 | cfg.work_dir.replace('oss://', '')) |
| 124 | else: |
| 125 | cfg.oss_work_dir = None |
| 126 | |
| 127 | # create work_dir |
| 128 | if not io.exists(cfg.work_dir): |
| 129 | io.makedirs(cfg.work_dir) |
| 130 | |
| 131 | # init the logger before other steps |
| 132 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) |
no test coverage detected