| 239 | |
| 240 | |
class KernelInitializer(object):
    # Matches `key=value` pairs inside a kernel's string form, e.g. the
    # "d=3", "gamma=None" and "c0=1" in "PolynomialKernel(d=3, gamma=None, c0=1)".
    # Keys may contain underscores; a value runs up to the next "," or ")".
    _KWARG_RE = re.compile(r"([a-zA-Z0-9_]+)=([^,)]*)")

    def __init__(self, param=None):
        """
        A class for initializing kernel objects. Valid inputs are:
            (a) __str__ representations of `KernelBase` instances
            (b) `KernelBase` instances
            (c) Parameter dicts (e.g., as produced via the :meth:`summary` method in
                `KernelBase` instances)

        If `param` is None, calling the initializer returns a default
        `LinearKernel()`.
        """
        self.param = param

    def __call__(self):
        """
        Build and return the kernel described by `self.param`.

        Returns
        -------
        kernel
            `param` itself if it is already a `KernelBase` instance; otherwise
            a newly constructed `LinearKernel`, `PolynomialKernel` or
            `RBFKernel`.

        Raises
        ------
        TypeError
            If `param` is not None, a `KernelBase` instance, a str, or a dict.
        """
        param = self.param
        if param is None:
            return LinearKernel()
        if isinstance(param, KernelBase):
            return param
        if isinstance(param, str):
            return self.init_from_str()
        if isinstance(param, dict):
            return self.init_from_dict()
        # Previously fell through to an unbound local; fail with a clear message.
        raise TypeError(
            "Unrecognized kernel parameter of type {}: {!r}".format(
                type(param).__name__, param
            )
        )

    @classmethod
    def _parse_kwargs(cls, spec):
        """
        Parse the ``key=value`` pairs in a kernel string into a kwargs dict.

        Each value is stripped of surrounding whitespace and then evaluated,
        so ``"None"`` becomes None and ``"3"`` becomes 3.
        """
        # NOTE(security): `eval` evaluates each value as an arbitrary Python
        # expression. Do not feed untrusted strings through this path; if only
        # literal values are ever expected, `ast.literal_eval` is the safer
        # choice.
        return {key: eval(val.strip()) for key, val in cls._KWARG_RE.findall(spec)}

    def init_from_str(self):
        """
        Build a kernel from its string representation.

        The kernel type is chosen by a case-insensitive substring match on
        "linear", "polynomial" or "rbf" (checked in that order), and any
        ``key=value`` pairs in the string are passed as keyword arguments.

        Raises
        ------
        NotImplementedError
            If none of the known kernel names appears in the string.
        """
        kr_str = self.param.lower()
        kwargs = self._parse_kwargs(self.param)

        if "linear" in kr_str:
            return LinearKernel(**kwargs)
        if "polynomial" in kr_str:
            return PolynomialKernel(**kwargs)
        if "rbf" in kr_str:
            return RBFKernel(**kwargs)
        raise NotImplementedError("{}".format(kr_str))

    def init_from_dict(self):
        """
        Build a kernel from a parameter dict.

        The dict must contain a non-empty ``"hyperparameters"`` mapping whose
        ``"id"`` names the kernel class. A default instance of that class is
        created, and the full dict is passed to its ``set_params`` method,
        whose return value is returned.

        Raises
        ------
        ValueError
            If ``"hyperparameters"`` is missing, None, or empty.
        KeyError
            If ``"hyperparameters"`` has no ``"id"`` entry.
        NotImplementedError
            If the ``"id"`` is not a known kernel class name.
        """
        S = self.param
        sc = S.get("hyperparameters")

        if sc is None:
            raise ValueError("Must have `hyperparameters` key: {}".format(S))
        if not sc:
            # An empty mapping previously fell through every branch and hit an
            # unbound local; report it explicitly instead.
            raise ValueError("`hyperparameters` must not be empty: {}".format(S))

        kernel_id = sc["id"]
        if kernel_id == "LinearKernel":
            kernel = LinearKernel()
        elif kernel_id == "PolynomialKernel":
            kernel = PolynomialKernel()
        elif kernel_id == "RBFKernel":
            kernel = RBFKernel()
        else:
            raise NotImplementedError("{}".format(kernel_id))
        return kernel.set_params(S)
| 297 | |
| 298 |
no outgoing calls
no test coverage detected