(config, args)
| 281 | |
| 282 | |
| 283 | def update_config(config, args): |
| 284 | _update_config_from_file(config, args.cfg) |
| 285 | |
| 286 | config.defrost() |
| 287 | if args.opts: |
| 288 | config.merge_from_list(args.opts) |
| 289 | |
def _check_args(name):
    """Return True if ``args`` has an attribute ``name`` with a truthy value.

    ``args`` is the argument namespace captured from the enclosing
    ``update_config`` call. A missing attribute and a falsy value (e.g.
    ``None``, ``False``, ``''``, ``0``) both yield False, so the caller only
    overrides config fields that were actually supplied.
    """
    # getattr with a default replaces the previous hasattr + eval(f'args.{name}')
    # pair: one lookup, and no dynamic code evaluation of the attribute name.
    return bool(getattr(args, name, None))
| 294 | |
| 295 | # merge from specific arguments |
| 296 | if _check_args('batch_size'): |
| 297 | config.DATA.BATCH_SIZE = args.batch_size |
| 298 | if _check_args('data_path'): |
| 299 | config.DATA.DATA_PATH = args.data_path |
| 300 | if _check_args('zip'): |
| 301 | config.DATA.ZIP_MODE = True |
| 302 | if _check_args('cache_mode'): |
| 303 | config.DATA.CACHE_MODE = args.cache_mode |
| 304 | if _check_args('pretrained'): |
| 305 | config.MODEL.PRETRAINED = args.pretrained |
| 306 | if _check_args('resume'): |
| 307 | config.MODEL.RESUME = args.resume |
| 308 | if _check_args('accumulation_steps'): |
| 309 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps |
| 310 | if _check_args('use_checkpoint'): |
| 311 | config.TRAIN.USE_CHECKPOINT = True |
| 312 | if _check_args('amp_opt_level'): |
| 313 | print("[warning] Apex amp has been deprecated, please use pytorch amp instead!") |
| 314 | if args.amp_opt_level == 'O0': |
| 315 | config.AMP_ENABLE = False |
| 316 | if _check_args('disable_amp'): |
| 317 | config.AMP_ENABLE = False |
| 318 | if _check_args('output'): |
| 319 | config.OUTPUT = args.output |
| 320 | if _check_args('tag'): |
| 321 | config.TAG = args.tag |
| 322 | if _check_args('eval'): |
| 323 | config.EVAL_MODE = True |
| 324 | if _check_args('throughput'): |
| 325 | config.THROUGHPUT_MODE = True |
| 326 | |
| 327 | # [SimMIM] |
| 328 | if _check_args('enable_amp'): |
| 329 | config.ENABLE_AMP = args.enable_amp |
| 330 | |
| 331 | # for acceleration |
| 332 | if _check_args('fused_window_process'): |
| 333 | config.FUSED_WINDOW_PROCESS = True |
| 334 | if _check_args('fused_layernorm'): |
| 335 | config.FUSED_LAYERNORM = True |
| 336 | ## Overwrite optimizer if not None, currently we use it for [fused_adam, fused_lamb] |
| 337 | if _check_args('optim'): |
| 338 | config.TRAIN.OPTIMIZER.NAME = args.optim |
| 339 | |
| 340 | # set local rank for distributed training |
no test coverage detected