MCPcopy
hub / github.com/microsoft/Swin-Transformer / build_optimizer

Function build_optimizer

optimizer.py:19–56  ·  view source on GitHub ↗

Build the optimizer; weight decay for normalization parameters is set to 0 by default.

(config, model, simmim=False, is_pretrain=False)

Source retrieved from the content-addressed store (hash-verified).

17
18
def build_optimizer(config, model, simmim=False, is_pretrain=False):
    """
    Build optimizer, set weight decay of normalization to 0 by default.

    Parameter groups come from one of three helpers defined elsewhere in
    this module, selected by ``simmim`` / ``is_pretrain``:

    * ``simmim and is_pretrain``     -> ``get_pretrain_param_groups``
    * ``simmim and not is_pretrain`` -> ``get_finetune_param_groups`` with
      per-layer scale factors ``config.TRAIN.LAYER_DECAY ** i``
    * otherwise                      -> ``set_weight_decay``

    If the model defines ``no_weight_decay()`` / ``no_weight_decay_keywords()``,
    their results are forwarded to those helpers as the skip collections.

    Args:
        config: Experiment config. Reads ``TRAIN.OPTIMIZER.NAME`` (plus
            ``MOMENTUM`` for SGD, ``EPS`` / ``BETAS`` for the Adam/LAMB
            variants), ``TRAIN.BASE_LR``, ``TRAIN.WEIGHT_DECAY``, and for
            SimMIM fine-tuning ``MODEL.TYPE``, ``MODEL.SWIN.DEPTHS`` or
            ``MODEL.SWINV2.DEPTHS``, and ``TRAIN.LAYER_DECAY``.
        model: Model whose parameters are optimized.
        simmim: Build SimMIM-style parameter groups.
        is_pretrain: Only meaningful with ``simmim``; selects the
            pre-training groups instead of the fine-tuning groups.

    Returns:
        An ``optim.SGD`` (Nesterov momentum), ``optim.AdamW``, ``FusedAdam``
        or ``FusedLAMB`` instance.

    Raises:
        ValueError: If ``config.TRAIN.OPTIMIZER.NAME`` (case-insensitive) is
            not one of ``sgd``, ``adamw``, ``fused_adam``, ``fused_lamb``.
        ImportError: If a fused optimizer is requested but its class is
            bound to ``None`` in this module.
    """
    # Parameter names / name fragments to exclude from weight decay; empty
    # unless the model declares them.
    skip = set()
    skip_keywords = set()
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()

    if simmim and is_pretrain:
        parameters = get_pretrain_param_groups(model, skip, skip_keywords)
    elif simmim:
        depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' else config.MODEL.SWINV2.DEPTHS
        num_layers = sum(depths)
        # NOTE(review): the "+ 2" presumably reserves layer ids for modules
        # outside the transformer blocks (e.g. embedding and head) - confirm
        # against get_swin_layer.
        get_layer_func = partial(get_swin_layer, num_layers=num_layers + 2, depths=depths)
        # scales[k] == LAYER_DECAY ** (num_layers + 1 - k): the first entry has
        # the largest exponent and the last entry is LAYER_DECAY ** 0 == 1.
        scales = [config.TRAIN.LAYER_DECAY ** i for i in reversed(range(num_layers + 2))]
        parameters = get_finetune_param_groups(
            model, config.TRAIN.BASE_LR, config.TRAIN.WEIGHT_DECAY,
            get_layer_func, scales, skip, skip_keywords)
    else:
        parameters = set_weight_decay(model, skip, skip_keywords)

    opt_cfg = config.TRAIN.OPTIMIZER
    opt_lower = opt_cfg.NAME.lower()
    lr = config.TRAIN.BASE_LR
    weight_decay = config.TRAIN.WEIGHT_DECAY

    if opt_lower == 'sgd':
        return optim.SGD(parameters, momentum=opt_cfg.MOMENTUM, nesterov=True,
                         lr=lr, weight_decay=weight_decay)
    if opt_lower == 'adamw':
        return optim.AdamW(parameters, eps=opt_cfg.EPS, betas=opt_cfg.BETAS,
                           lr=lr, weight_decay=weight_decay)
    if opt_lower in ('fused_adam', 'fused_lamb'):
        # Only the requested class name is evaluated here.
        fused_cls = FusedAdam if opt_lower == 'fused_adam' else FusedLAMB
        # NOTE(review): assumes the module may bind FusedAdam / FusedLAMB to
        # None when the providing package is unavailable; fail with a clear
        # message instead of "'NoneType' object is not callable" - confirm
        # against the module imports.
        if fused_cls is None:
            raise ImportError(
                f"Optimizer '{opt_lower}' requested but its implementation is not available")
        return fused_cls(parameters, eps=opt_cfg.EPS, betas=opt_cfg.BETAS,
                         lr=lr, weight_decay=weight_decay)

    # Previously an unknown name fell through and returned None, which only
    # failed later at the first use of the optimizer.
    raise ValueError(
        f"Unsupported optimizer {opt_cfg.NAME!r}; expected one of "
        "'sgd', 'adamw', 'fused_adam', 'fused_lamb'")
57
58
59def set_weight_decay(model, skip_list=(), skip_keywords=()):

Callers 4

main · Function · 0.90
main · Function · 0.90
main · Function · 0.90
main · Function · 0.90

Calls 5

set_weight_decay · Function · 0.85
no_weight_decay · Method · 0.45

Tested by

no test coverage detected