MCPcopy
hub / github.com/microsoft/Cream / create_optimizer_supernet

Function create_optimizer_supernet

Cream/lib/utils/util.py:96–132  ·  view source on GitHub ↗
(args, model, has_apex, filter_bias_and_bn=True)

Source from the content-addressed store, hash-verified

94
95
def create_optimizer_supernet(args, model, has_apex, filter_bias_and_bn=True):
    """Build the optimizer used to train the supernet.

    The optimizer is selected from ``args.opt`` (case-insensitive). Only the
    text after the last ``_`` chooses the optimizer class, so e.g.
    ``"fused_sgd"`` resolves to ``"sgd"`` once the APEX/CUDA check passes.

    Supported names (after that split):
        * ``sgd`` / ``nesterov`` -- ``optim.SGD`` with Nesterov momentum.
        * ``momentum``           -- ``optim.SGD`` with classic momentum.
        * ``adam``               -- ``optim.Adam``.

    Args:
        args: Config namespace. Reads ``opt``, ``weight_decay``, ``lr`` and,
            depending on the optimizer, ``momentum`` or ``opt_eps``.
        model: Module whose parameters are optimized.
        has_apex: Whether NVIDIA APEX is available; only checked when the
            optimizer name contains ``"fused"``.
        filter_bias_and_bn: When true and weight decay is non-zero, build
            parameter groups with ``add_weight_decay_supernet`` instead of
            applying a single weight decay to every parameter.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        AssertionError: If a ``fused`` optimizer is requested without APEX
            and CUDA (only when assertions are enabled).
        ValueError: If the optimizer name is not supported.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Rescale the configured decay by the learning rate.
        # NOTE(review): neither 'adamw' nor 'radam' is handled by the
        # dispatch below, so these names currently end in ValueError.
        weight_decay /= args.lr
    if weight_decay and filter_bias_and_bn:
        # The returned parameter groups carry their own settings (presumably
        # per-group weight decay and lr -- see add_weight_decay_supernet), so
        # no global decay is applied on top of them.
        parameters = add_weight_decay_supernet(model, args, weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), \
            'APEX and CUDA required for fused optimizers'

    opt_name = opt_lower.split('_')[-1]
    # ``lr=args.lr`` is only the default for groups that do not set their
    # own ``lr``; without it the unfiltered path (plain model.parameters())
    # would use the optimizer's built-in default instead of args.lr (and
    # older PyTorch SGD rejects such groups outright).
    if opt_name in ('sgd', 'nesterov'):
        optimizer = optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=weight_decay,
            nesterov=True)
    elif opt_name == 'momentum':
        optimizer = optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=weight_decay,
            nesterov=False)
    elif opt_name == 'adam':
        optimizer = optim.Adam(
            parameters,
            lr=args.lr,
            weight_decay=weight_decay,
            eps=args.opt_eps)
    else:
        # The previous `assert False and "..."` never showed its message and
        # vanished under `python -O`; raise an informative error instead.
        raise ValueError('Invalid optimizer: {}'.format(args.opt))

    return optimizer
133
134
135def convert_lowercase(cfg):

Callers 1

mainFunction · 0.90

Calls 1

Tested by

no test coverage detected