Build optimizer from configs. Args: model (:obj:`nn.Module`): The model with parameters to be optimized. optimizer_cfg (dict): The config dict of the optimizer. Positional fields are: - type: class name of the optimizer. - lr: base le
(model, optimizer_cfg)
| 348 | |
| 349 | |
| 350 | def build_optimizer(model, optimizer_cfg): |
| 351 | """Build optimizer from configs. |
| 352 | |
| 353 | Args: |
| 354 | model (:obj:`nn.Module`): The model with parameters to be optimized. |
| 355 | optimizer_cfg (dict): The config dict of the optimizer. |
| 356 | |
| 357 | Positional fields are: |
| 358 | - type: class name of the optimizer. |
| 359 | - lr: base learning rate. |
| 360 | |
| 361 | Optional fields are: |
| 362 | - any arguments of the corresponding optimizer type, e.g., |
| 363 | weight_decay, momentum, etc. |
| 364 | - paramwise_options: a dict with regular expression as keys |
| 365 | to match parameter names and a dict containing options as |
| 366 | values. Options include 6 fields: lr, lr_mult, momentum, |
| 367 | momentum_mult, weight_decay, weight_decay_mult. |
| 368 | |
| 369 | Returns: |
| 370 | torch.optim.Optimizer: The initialized optimizer. |
| 371 | |
| 372 | Example: |
| 373 | >>> model = torch.nn.modules.Conv1d(1, 1, 1) |
| 374 | >>> paramwise_options = { |
| 375 | >>> '(bn|gn)(\d+)?.(weight|bias)': dict(weight_decay_mult=0.1), |
| 376 | >>> '\Ahead.': dict(lr_mult=10, momentum=0)} |
| 377 | >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, |
| 378 | >>> weight_decay=0.0001, |
| 379 | >>> paramwise_options=paramwise_options) |
| 380 | >>> optimizer = build_optimizer(model, optimizer_cfg) |
| 381 | """ |
| 382 | |
| 383 | if hasattr(model, 'module'): |
| 384 | model = model.module |
| 385 | |
| 386 | # Some special models (e.g. DINO) only need to optimize part of their parameters. Such a model |
| 387 | # provides a get_params_groups attribute used to initialize the optimizer; if it is present, use it here. |
| 388 | if hasattr(model, 'get_params_groups'): |
| 389 | print('type : ', type(model), |
| 390 | 'trigger opimizer model param_groups set for DINO') |
| 391 | parameters = model.get_params_groups() |
| 392 | optimizer_cfg = optimizer_cfg.copy() |
| 393 | optimizer_cls = getattr(optimizer, optimizer_cfg.pop('type')) |
| 394 | return optimizer_cls(parameters, **optimizer_cfg) |
| 395 | |
| 396 | # For some transformer-based models (swin/shuffle/cswin), the bias should be set with no weight decay. |
| 397 | set_var_bias_nowd = optimizer_cfg.pop('set_var_bias_nowd', None) |
| 398 | if set_var_bias_nowd is None: |
| 399 | set_var_bias_nowd = optimizer_cfg.pop( |
| 400 | 'trans_weight_decay_set', None |
| 401 | ) # Fallback for configs from older versions: set_var_bias_nowd used to be called trans_weight_decay_set. |
| 402 | if set_var_bias_nowd is not None: |
| 403 | print('type : ', type(model), 'trigger transformer set_var_bias_nowd') |
| 404 | skip = [] |
| 405 | skip_keywords = [] |
| 406 | assert (type(set_var_bias_nowd) is list) |
| 407 | for model_part in set_var_bias_nowd: |
no test coverage detected