hub / github.com/alibaba/EasyCV / build_optimizer

Function build_optimizer

easycv/apis/train.py:350–462

Build optimizer from configs.

(model, optimizer_cfg)
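
The docstring describes `paramwise_options` as a dict whose keys are regular expressions matched against parameter names and whose values override options such as `lr` or `weight_decay_mult`. The matching code itself is not visible in this excerpt, so the following is only a minimal sketch of how such overrides could be resolved. It assumes `re.search` semantics and that a `*_mult` option multiplies the corresponding base value from the config. The helper `resolve_param_options` and the parameter names are hypothetical and not part of EasyCV.

```python
import re


def resolve_param_options(param_names, base_cfg, paramwise_options):
    """Return {param_name: overrides} for regex-keyed per-parameter options.

    Hypothetical helper for illustration only; not part of EasyCV.
    """
    resolved = {}
    for name in param_names:
        options = {}
        for pattern, overrides in paramwise_options.items():
            # Assumption: a pattern applies if it matches anywhere in the name.
            if not re.search(pattern, name):
                continue
            for key, value in overrides.items():
                if key.endswith('_mult'):
                    # Assumption: e.g. lr_mult=10 scales the base lr.
                    base_key = key[:-len('_mult')]
                    options[base_key] = base_cfg[base_key] * value
                else:
                    options[key] = value
        resolved[name] = options
    return resolved


base_cfg = dict(lr=0.01, momentum=0.9, weight_decay=0.0001)
paramwise_options = {
    r'(bn|gn)(\d+)?.(weight|bias)': dict(weight_decay_mult=0.1),
    r'\Ahead.': dict(lr_mult=10, momentum=0),
}
# Made-up parameter names in the style of model.named_parameters().
names = ['backbone.conv1.weight', 'backbone.bn1.weight', 'head.fc.bias']

resolved = resolve_param_options(names, base_cfg, paramwise_options)
for name in names:
    print(name, resolved[name])
```

Under these assumptions, the convolution weight keeps the base settings, the batch-norm weight gets a reduced weight decay, and the head bias gets a larger learning rate with momentum disabled.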

Source from the content-addressed store, hash-verified

def build_optimizer(model, optimizer_cfg):
    r"""Build optimizer from configs.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict): The config dict of the optimizer.

            Required fields are:
                - type: class name of the optimizer.
                - lr: base learning rate.

            Optional fields are:
                - any arguments of the corresponding optimizer type, e.g.,
                  weight_decay, momentum, etc.
                - paramwise_options: a dict with regular expressions as keys
                  to match parameter names and a dict containing options as
                  values. Options include 6 fields: lr, lr_mult, momentum,
                  momentum_mult, weight_decay, weight_decay_mult.

    Returns:
        torch.optim.Optimizer: The initialized optimizer.

    Example:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> paramwise_options = {
        ...     r'(bn|gn)(\d+)?.(weight|bias)': dict(weight_decay_mult=0.1),
        ...     r'\Ahead.': dict(lr_mult=10, momentum=0)}
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        ...                      weight_decay=0.0001,
        ...                      paramwise_options=paramwise_options)
        >>> optimizer = build_optimizer(model, optimizer_cfg)
    """

    # Unwrap a wrapped model (e.g. a DataParallel-style wrapper).
    if hasattr(model, 'module'):
        model = model.module

    # Some models (e.g. DINO) optimize only part of their parameters. Such a
    # model exposes get_params_groups(); when that attribute is present, its
    # groups are passed to the optimizer directly.
    if hasattr(model, 'get_params_groups'):
        print('type : ', type(model),
              'trigger optimizer model param_groups set for DINO')
        parameters = model.get_params_groups()
        optimizer_cfg = optimizer_cfg.copy()
        # `optimizer` is a module-level name defined outside this excerpt.
        optimizer_cls = getattr(optimizer, optimizer_cfg.pop('type'))
        return optimizer_cls(parameters, **optimizer_cfg)

    # For transformer-based models (swin/shuffle/cswin), biases should be set
    # to use no weight decay.
    set_var_bias_nowd = optimizer_cfg.pop('set_var_bias_nowd', None)
    if set_var_bias_nowd is None:
        # Fallback for older configs, in which set_var_bias_nowd was called
        # trans_weight_decay_set.
        set_var_bias_nowd = optimizer_cfg.pop('trans_weight_decay_set', None)
    if set_var_bias_nowd is not None:
        print('type : ', type(model), 'trigger transformer set_var_bias_nowd')
        skip = []
        skip_keywords = []
        assert (type(set_var_bias_nowd) is list)
        for model_part in set_var_bias_nowd:
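
The `get_params_groups` branch in the listing copies the config, pops `type`, looks the optimizer class up by name with `getattr`, and passes the model-supplied parameter groups plus the remaining config entries to it. The sketch below reproduces only that pattern with stand-ins so it runs without any deep-learning library. `FakeSGD`, `optimizer_registry`, `ToyModel`, and `build_from_groups` are hypothetical names introduced for illustration; in the real function the class comes from the module-level `optimizer` name and the groups come from the actual model.

```python
from types import SimpleNamespace


class FakeSGD:
    """Stand-in for an optimizer class; it only records its arguments."""

    def __init__(self, params, lr, momentum=0.0):
        self.param_groups = list(params)
        self.lr = lr
        self.momentum = momentum


# Stand-in for the namespace the source calls `optimizer`.
optimizer_registry = SimpleNamespace(SGD=FakeSGD)


class ToyModel:
    """Hypothetical model exposing the get_params_groups hook."""

    def get_params_groups(self):
        return [
            {'params': ['w1', 'w2']},
            {'params': ['b1'], 'weight_decay': 0.0},
        ]


def build_from_groups(model, optimizer_cfg):
    # Copy first so popping 'type' does not mutate the caller's dict.
    cfg = optimizer_cfg.copy()
    optimizer_cls = getattr(optimizer_registry, cfg.pop('type'))
    return optimizer_cls(model.get_params_groups(), **cfg)


cfg = dict(type='SGD', lr=0.01, momentum=0.9)
opt = build_from_groups(ToyModel(), cfg)
print(type(opt).__name__, len(opt.param_groups), cfg['type'])
```

Because the config is copied before `type` is popped, the caller's dict still contains `type` afterwards. In the listing, the later `set_var_bias_nowd` branch pops from `optimizer_cfg` without copying, so on that path the caller's dict is modified.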

Callers 3

main · Function · 0.90
train_model · Function · 0.85

Calls 7

print_log · Function · 0.90
get_skip_list_keywords · Function · 0.85
_set_weight_decay · Function · 0.85
get_params_groups · Method · 0.80
add_params · Method · 0.80
copy · Method · 0.45

Tested by

no test coverage detected