MCP · copy
hub / github.com/ddbourgin/numpy-ml / OptimizerInitializer

Class OptimizerInitializer

numpy_ml/neural_nets/initializers/initializers.py:176–239  ·  view source on GitHub ↗

Source retrieved from the content-addressed store (hash-verified).

174
175
176class OptimizerInitializer(object):
177 def __init__(self, param=None):
178 """
179 A class for initializing optimizers. Valid `param` values are:
180 (a) __str__ representations of `OptimizerBase` instances
181 (b) `OptimizerBase` instances
182 (c) Parameter dicts (e.g., as produced via the `summary` method in
183 `LayerBase` instances)
184
185 If `param` is `None`, return the SGD optimizer with default parameters.
186 """
187 self.param = param
188
189 def __call__(self):
190 """Initialize the optimizer"""
191 param = self.param
192 if param is None:
193 opt = SGD()
194 elif isinstance(param, OptimizerBase):
195 opt = param
196 elif isinstance(param, str):
197 opt = self.init_from_str()
198 elif isinstance(param, dict):
199 opt = self.init_from_dict()
200 return opt
201
202 def init_from_str(self):
203 """Initialize optimizer from the `param` string"""
204 r = r"([a-zA-Z]*)=([^,)]*)"
205 opt_str = self.param.lower()
206 kwargs = {i: _eval(j) for i, j in re.findall(r, opt_str)}
207 if "sgd" in opt_str:
208 optimizer = SGD(**kwargs)
209 elif "adagrad" in opt_str:
210 optimizer = AdaGrad(**kwargs)
211 elif "rmsprop" in opt_str:
212 optimizer = RMSProp(**kwargs)
213 elif "adam" in opt_str:
214 optimizer = Adam(**kwargs)
215 else:
216 raise NotImplementedError("{}".format(opt_str))
217 return optimizer
218
219 def init_from_dict(self):
220 """Initialize optimizer from the `param` dictonary"""
221 D = self.param
222 cc = D["cache"] if "cache" in D else None
223 op = D["hyperparameters"] if "hyperparameters" in D else None
224
225 if op is None:
226 raise ValueError("`param` dictionary has no `hyperparemeters` key")
227
228 if op and op["id"] == "SGD":
229 optimizer = SGD()
230 elif op and op["id"] == "RMSProp":
231 optimizer = RMSProp()
232 elif op and op["id"] == "AdaGrad":
233 optimizer = AdaGrad()

Callers (3)

`__init__` (method) · 0.85
`set_params` (method) · 0.85
`__init__` (method) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected