| 276 | """ |
| 277 | |
def __init__(self, param_distributions, n_iter, *, random_state=None):
    """Validate and store the parameter distributions to sample from.

    Parameters
    ----------
    param_distributions : dict or iterable of dict
        Either a single mapping from parameter name to candidate values, or
        an iterable of such dicts. A single mapping is wrapped in a one-item
        list. Every value must be iterable or have an ``rvs`` attribute.
    n_iter : int
        Stored unchanged as ``self.n_iter``; not validated here.
    random_state : optional
        Stored unchanged as ``self.random_state``; not validated here.

    Raises
    ------
    TypeError
        If ``param_distributions`` is neither a mapping nor an iterable, if
        any of its elements is not a ``dict``, or if any parameter value is
        neither iterable nor has an ``rvs`` attribute.
    """
    from collections.abc import Iterator

    if not isinstance(param_distributions, (Mapping, Iterable)):
        raise TypeError(
            "Parameter distribution is not a dict or a list,"
            f" got: {param_distributions!r} of type "
            f"{type(param_distributions).__name__}"
        )

    if isinstance(param_distributions, Mapping):
        # wrap dictionary in a singleton list to support either dict
        # or list of dicts
        param_distributions = [param_distributions]
    elif isinstance(param_distributions, Iterator):
        # A one-shot iterator (e.g. a generator) would be consumed by the
        # validation loop below and then stored already exhausted, so
        # materialize it first. Re-iterable containers are kept as-is.
        param_distributions = list(param_distributions)

    for dist in param_distributions:
        if not isinstance(dist, dict):
            raise TypeError(
                "Parameter distribution is not a dict ({!r})".format(dist)
            )
        for key in dist:
            # Accept either an explicit collection of candidate values or an
            # object exposing ``rvs`` (sampled from rather than enumerated).
            if not isinstance(dist[key], Iterable) and not hasattr(
                dist[key], "rvs"
            ):
                raise TypeError(
                    f"Parameter grid for parameter {key!r} is not iterable "
                    f"or a distribution (value={dist[key]})"
                )
    self.n_iter = n_iter
    self.random_state = random_state
    self.param_distributions = param_distributions
| 307 | |
| 308 | def _is_all_lists(self): |
| 309 | return all( |