Function main()

github.com/alibaba/EasyCV · branch main · tools/train.py, lines 116–315

```python
# Module-level imports this excerpt relies on. parse_args, CONFIG_TEMPLATE_ZOO,
# mmcv_config_fromfile, pai_config_fromfile and setup_multi_processes are
# defined or imported elsewhere in tools/train.py (not shown here).
import os.path as osp

import requests
import torch


def main():
    args = parse_args()

    # --model_type replaces --config with a bundled config template.
    if args.model_type is not None:
        assert args.model_type in CONFIG_TEMPLATE_ZOO, (
            'model_type must be in [%s]' % ', '.join(CONFIG_TEMPLATE_ZOO.keys()))
        print('model_type=%s, config file will be replaced by %s' %
              (args.model_type, CONFIG_TEMPLATE_ZOO[args.model_type]))
        args.config = CONFIG_TEMPLATE_ZOO[args.model_type]

    if args.config.startswith('http'):
        # Download the config into the current directory, named after the
        # last URL segment; an existing local file with that name is reused.
        tpath = args.config.split('/')[-1]
        if not osp.exists(tpath):
            r = requests.get(args.config)
            r.raise_for_status()  # fail loudly instead of saving an error page
            with open(tpath, 'wb') as f:
                f.write(r.content)
        args.config = tpath

    # build cfg
    if args.user_config_params is None:
        cfg = mmcv_config_fromfile(args.config)
    else:
        cfg = pai_config_fromfile(args.config, args.user_config_params,
                                  args.model_type)

    # set multi-process settings
    setup_multi_processes(cfg)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir

    if cfg.get('work_dir', None) is None:
        cfg.work_dir = './work_dir'

    # If `work_dir` is an OSS path, redirect `work_dir` to a local path and keep
    # the OSS path in `oss_work_dir`; the osssync hook then uploads logs and
    # checkpoints from `work_dir` to `oss_work_dir`.
    if cfg.work_dir.startswith('oss://'):
        cfg.oss_work_dir = cfg.work_dir
        cfg.work_dir = osp.join('work_dirs',
                                cfg.work_dir.replace('oss://', ''))
    else:
        cfg.oss_work_dir = None

    # None and empty strings both count as "not given".
    if args.resume_from:
        cfg.resume_from = args.resume_from
    if args.load_from:
        cfg.load_from = args.load_from

    # ... (lines 174–315 of tools/train.py are not shown in this excerpt)
```
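
The config-source logic at the top of main() can be exercised without EasyCV or a network connection. In that logic, a --model_type template replaces --config, and an http(s) config is saved locally under its last URL segment. The sketch below is a hypothetical helper, resolve_config_path, and is not part of EasyCV. It makes three substitutions. template_zoo stands in for CONFIG_TEMPLATE_ZOO. fetch stands in for the HTTP download, which in practice would be something like lambda url: requests.get(url).content. It raises ValueError where main() uses assert. The download only happens when no local file of that name exists yet.

```python
import os.path as osp
from typing import Callable, Dict, Optional


def resolve_config_path(
    config: str,
    model_type: Optional[str],
    template_zoo: Dict[str, str],
    fetch: Callable[[str], bytes],
    download_dir: str = '.',
) -> str:
    """Return a local config path (hypothetical helper, not EasyCV API).

    template_zoo plays the role of CONFIG_TEMPLATE_ZOO; fetch plays the role
    of downloading the URL. Both are injected so the logic is testable.
    """
    # 1. A model type overrides the given config with a bundled template.
    if model_type is not None:
        if model_type not in template_zoo:
            raise ValueError('model_type must be in [%s]' % ', '.join(template_zoo))
        config = template_zoo[model_type]

    # 2. A URL config is saved under its last path segment, only once.
    if config.startswith('http'):
        local_path = osp.join(download_dir, config.split('/')[-1])
        if not osp.exists(local_path):
            with open(local_path, 'wb') as f:
                f.write(fetch(config))
        config = local_path

    return config
```

Injecting fetch keeps the decision logic (which source wins, when to download) separate from the I/O, so it can be checked with a fake downloader and a temporary directory.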

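The work_dir and checkpoint handling in main() reduces to a small precedence rule. A CLI value wins over the config, and './work_dir' is the fallback. An oss:// work_dir is mirrored under a local work_dirs/ directory, with the original path kept in oss_work_dir. A resume_from or load_from that is None or an empty string is ignored. The sketch below assumes a plain dict in place of the mmcv Config object, and apply_cli_overrides is a hypothetical name.

```python
import os.path as osp
from typing import Any, Dict, Optional


def apply_cli_overrides(
    cfg: Dict[str, Any],
    work_dir: Optional[str] = None,
    resume_from: Optional[str] = None,
    load_from: Optional[str] = None,
) -> Dict[str, Any]:
    """Apply main()'s work_dir / checkpoint precedence to a plain dict."""
    cfg = dict(cfg)  # work on a copy; the caller's dict is left untouched

    # A CLI work_dir wins; otherwise keep the config value or fall back.
    if work_dir is not None:
        cfg['work_dir'] = work_dir
    if cfg.get('work_dir') is None:
        cfg['work_dir'] = './work_dir'

    # An OSS work_dir is redirected to a local mirror under 'work_dirs/'.
    if cfg['work_dir'].startswith('oss://'):
        cfg['oss_work_dir'] = cfg['work_dir']
        cfg['work_dir'] = osp.join('work_dirs', cfg['work_dir'].replace('oss://', ''))
    else:
        cfg['oss_work_dir'] = None

    # None and empty strings leave the config value unchanged.
    if resume_from:
        cfg['resume_from'] = resume_from
    if load_from:
        cfg['load_from'] = load_from
    return cfg
```

Returning a copy keeps the function free of side effects, which makes the precedence easy to test. main() itself mutates cfg in place.
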
Callers (1)

train.py (file)

Calls (15; 12 listed)

mmcv_config_fromfile
pai_config_fromfile
setup_multi_processes
traverse_replace
is_torchacc_enabled
get_root_logger
collect_env
init_random_seed
get_device
set_random_seed
build_model
is_master

Tested by: no test coverage detected.