```python
def obs_stats(env, md_obs, cont_obs):
    """
    Get information on the observation space for `env`.

    Parameters
    ----------
    env : ``gym.wrappers`` or ``gym.envs`` instance
        The environment to evaluate.
    md_obs : bool
        Whether the `env`'s observation space is multidimensional.
    cont_obs : bool
        Whether the `env`'s observation space is continuous.

    Returns
    -------
    n_obs_per_dim : list of length (obs_dim,)
        The number of possible observation classes for each dimension of the
        observation space. Continuous dimensions, and sub-spaces without an
        ``n`` attribute, are reported as ``np.inf``.
    obs_ids : list or None
        A list of all valid observations within the space. This is None if
        `cont_obs` is True, or if any dimension of a multidimensional
        observation space has an infinite number of classes.
    obs_dim : int
        The number of dimensions in a single observation.
    """
    if cont_obs:
        # Continuous space: infinitely many values along every dimension
        obs_ids = None
        obs_dim = env.observation_space.shape[0]
        n_obs_per_dim = [np.inf for _ in range(obs_dim)]
    else:
        if md_obs:
            # Tuple of sub-spaces: count the classes in each discrete
            # sub-space and treat any sub-space without `n` as infinite
            n_obs_per_dim = [
                space.n if hasattr(space, "n") else np.inf
                for space in env.observation_space.spaces
            ]
            # Enumerate every observation only if all dimensions are finite
            obs_ids = (
                None
                if np.inf in n_obs_per_dim
                else list(product(*[range(i) for i in n_obs_per_dim]))
            )
            obs_dim = len(n_obs_per_dim)
        else:
            # Single discrete space: observations are 0, 1, ..., n - 1
            obs_dim = 1
            n_obs_per_dim = [env.observation_space.n]
            obs_ids = list(range(n_obs_per_dim[0]))

    return n_obs_per_dim, obs_ids, obs_dim
```

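To see how the three branches of `obs_stats` behave without installing gym, the sketch below re-implements the same logic as a hypothetical `obs_stats_sketch` and feeds it minimal stand-in spaces. The classes `FakeDiscrete`, `FakeBox`, `FakeTuple`, and `FakeEnv` are assumptions made up for illustration: they expose only the `n`, `shape`, and `spaces` attributes that the function reads and are not gym's actual space API. The sketch also uses `math.inf` in place of `np.inf` to stay dependency-free (the two compare equal).

```python
from dataclasses import dataclass
from itertools import product
import math


# Hypothetical stand-ins for gym spaces: they expose only the attributes
# that obs_stats reads (`n`, `shape`, `spaces`), not gym's full API.
@dataclass
class FakeDiscrete:
    n: int


@dataclass
class FakeBox:
    shape: tuple


@dataclass
class FakeTuple:
    spaces: tuple


@dataclass
class FakeEnv:
    observation_space: object


def obs_stats_sketch(env, md_obs, cont_obs):
    """Same branching as obs_stats, with math.inf in place of np.inf."""
    if cont_obs:
        # Continuous observations: every dimension has infinitely many values
        obs_dim = env.observation_space.shape[0]
        return [math.inf] * obs_dim, None, obs_dim
    if md_obs:
        # One entry per sub-space; sub-spaces without `n` count as infinite
        n_obs_per_dim = [
            s.n if hasattr(s, "n") else math.inf
            for s in env.observation_space.spaces
        ]
        # Only enumerate observations when every dimension is finite
        obs_ids = (
            None
            if math.inf in n_obs_per_dim
            else list(product(*[range(i) for i in n_obs_per_dim]))
        )
        return n_obs_per_dim, obs_ids, len(n_obs_per_dim)
    # Single discrete space: observations are the integers 0..n-1
    n = env.observation_space.n
    return [n], list(range(n)), 1


# Usage: a 2 x 3 grid of discrete observations
grid = FakeEnv(FakeTuple((FakeDiscrete(2), FakeDiscrete(3))))
n_per_dim, ids, dim = obs_stats_sketch(grid, md_obs=True, cont_obs=False)
print(n_per_dim, dim)  # [2, 3] 2
print(ids)  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```

Because `itertools.product` advances its last argument fastest, `obs_ids` lists the multidimensional observations in row-major order. Replacing either `FakeDiscrete` with a `FakeBox` makes that dimension infinite, so `obs_ids` becomes None while `obs_dim` is still the number of sub-spaces.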
| 430 | def env_stats(env): |