| 174 | |
| 175 | |
| 176 | class OptimizerInitializer(object): |
| 177 | def __init__(self, param=None): |
| 178 | """ |
| 179 | A class for initializing optimizers. Valid `param` values are: |
| 180 | (a) __str__ representations of `OptimizerBase` instances |
| 181 | (b) `OptimizerBase` instances |
| 182 | (c) Parameter dicts (e.g., as produced via the `summary` method in |
| 183 | `LayerBase` instances) |
| 184 | |
| 185 | If `param` is `None`, return the SGD optimizer with default parameters. |
| 186 | """ |
| 187 | self.param = param |
| 188 | |
def __call__(self):
    """
    Resolve `self.param` into an optimizer object.

    Returns
    -------
    The optimizer selected by `self.param`:
        - `None`                      -> a default-constructed `SGD()`
        - an `OptimizerBase` instance -> that same object, unchanged
        - a `str`                     -> the result of `init_from_str()`
        - a `dict`                    -> the result of `init_from_dict()`

    Raises
    ------
    TypeError
        If `self.param` is none of the kinds listed above. Without this
        check the method fell through every branch and died with an
        `UnboundLocalError` on the final `return`, which hid the real
        cause from the caller.
    """
    param = self.param
    if param is None:
        return SGD()
    if isinstance(param, OptimizerBase):
        # Already an optimizer: hand the very same object back.
        return param
    if isinstance(param, str):
        return self.init_from_str()
    if isinstance(param, dict):
        return self.init_from_dict()
    raise TypeError(
        "Cannot initialize an optimizer from a value of type {}: {!r}".format(
            type(param).__name__, param
        )
    )
| 201 | |
def init_from_str(self):
    """
    Build an optimizer from the string stored in `self.param`.

    The string is lower-cased, every `name=value` pair found in it becomes
    a keyword argument (each value is converted with `_eval`), and the
    optimizer class is chosen by substring match, tested in the fixed
    order: sgd, adagrad, rmsprop, adam.

    Raises
    ------
    NotImplementedError
        If the lower-cased string contains none of the known optimizer
        names; the message is the lower-cased string itself.
    """
    spec = self.param.lower()

    # NOTE(review): the name group matches ASCII letters only, so just the
    # run of letters directly before `=` becomes the key (for example
    # `clip_norm=5` yields the key `norm`) -- confirm this is intended.
    pair_pattern = r"([a-zA-Z]*)=([^,)]*)"
    kwargs = {}
    for name, raw_value in re.findall(pair_pattern, spec):
        kwargs[name] = _eval(raw_value)

    if "sgd" in spec:
        return SGD(**kwargs)
    if "adagrad" in spec:
        return AdaGrad(**kwargs)
    if "rmsprop" in spec:
        return RMSProp(**kwargs)
    if "adam" in spec:
        return Adam(**kwargs)
    raise NotImplementedError("{}".format(spec))
| 218 | |
| 219 | def init_from_dict(self): |
| 220 | """Initialize optimizer from the `param` dictonary""" |
| 221 | D = self.param |
| 222 | cc = D["cache"] if "cache" in D else None |
| 223 | op = D["hyperparameters"] if "hyperparameters" in D else None |
| 224 | |
| 225 | if op is None: |
| 226 | raise ValueError("`param` dictionary has no `hyperparemeters` key") |
| 227 | |
| 228 | if op and op["id"] == "SGD": |
| 229 | optimizer = SGD() |
| 230 | elif op and op["id"] == "RMSProp": |
| 231 | optimizer = RMSProp() |
| 232 | elif op and op["id"] == "AdaGrad": |
| 233 | optimizer = AdaGrad() |
no outgoing calls
no test coverage detected