(config, args)
| 199 | |
| 200 | |
def update_config(config, args):
    """Populate ``config`` from the config file and command-line overrides in ``args``.

    Steps, in order:
      1. Load the file named by ``args.cfg`` via ``_update_config_from_file``.
      2. Apply ``args.opts`` through ``config.merge_from_list``.
      3. Copy individual command-line flags onto their config keys. A flag that
         is falsy (or absent from ``args``) leaves the config value untouched.
      4. Resolve the local rank, falling back to the ``LOCAL_RANK`` environment
         variable when ``args.local_rank`` is ``None`` or missing.
      5. Extend ``OUTPUT`` with ``MODEL.NAME`` and ``TAG``.

    ``config`` is defrosted for the overrides and frozen again before returning.

    Args:
        config: configuration node exposing ``defrost``/``freeze``/
            ``merge_from_list`` (presumably a yacs ``CfgNode`` -- confirm
            against the caller).
        args: parsed command-line namespace. Only ``args.cfg`` is required;
            every other attribute read here is optional, so parsers that do not
            define some flags (e.g. ``only_cpu``) no longer raise
            ``AttributeError``.

    Side effects:
        When the rank is taken from ``LOCAL_RANK``, it is also written back to
        ``args.local_rank``, as before.
    """
    _update_config_from_file(config, args.cfg)

    def _arg(name):
        # Treat an attribute the parser never defined like an unset flag.
        return getattr(args, name, None)

    config.defrost()
    if _arg('opts'):
        config.merge_from_list(args.opts)

    # merge from specific arguments
    if _arg('batch_size'):
        config.DATA.BATCH_SIZE = args.batch_size
    if _arg('data_path'):
        config.DATA.DATA_PATH = args.data_path
    if _arg('pretrained'):
        config.MODEL.PRETRAINED = args.pretrained
    if _arg('resume'):
        config.MODEL.RESUME = args.resume
    if _arg('accumulation_steps'):
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if _arg('use_checkpoint'):
        config.TRAIN.USE_CHECKPOINT = True
    # Mixed precision is switched off on explicit request or for CPU-only runs.
    if _arg('disable_amp') or _arg('only_cpu'):
        config.AMP_ENABLE = False
    if _arg('output'):
        config.OUTPUT = args.output
    if _arg('tag'):
        config.TAG = args.tag
    if _arg('eval'):
        config.EVAL_MODE = True
    if _arg('throughput'):
        config.THROUGHPUT_MODE = True

    # set local rank for distributed training; when no rank was passed on the
    # command line, fall back to the LOCAL_RANK environment variable (exported
    # by some distributed launchers).
    local_rank = _arg('local_rank')
    if local_rank is None and 'LOCAL_RANK' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        args.local_rank = local_rank
    config.LOCAL_RANK = local_rank

    # output folder: <OUTPUT>/<MODEL.NAME>/<TAG>
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()
| 242 | |
| 243 | |
| 244 | def get_config(args=None): |
no test coverage detected