Check that the input will be compatible with Stable-Baselines when the observation is apparently an image. :param observation_space: Observation space :param key: When the observation space comes from a Dict space, we pass the corresponding key to have more precise warning messages.
(observation_space: spaces.Box, key: str = "")
| 55 | |
| 56 | |
def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
    """
    Check that the input will be compatible with Stable-Baselines
    when the observation is apparently an image.

    This function only emits warnings (``warnings.warn`` with the default
    ``UserWarning`` category); it never raises. It checks that:

    * the dtype is ``np.uint8``;
    * every entry of ``low`` is exactly 0 and every entry of ``high`` is exactly 255;
    * both spatial (non-channel) dimensions are at least 36.

    :param observation_space: Observation space
    :param key: When the observation space comes from a Dict space, we pass the
        corresponding key to have more precise warning messages. Defaults to "".
    """
    if observation_space.dtype != np.uint8:
        warnings.warn(
            f"It seems that your observation {key} is an image but its `dtype` "
            f"is ({observation_space.dtype}) whereas it has to be `np.uint8`. "
            "If your observation is not an image, we recommend you to flatten the observation "
            "to have only a 1D vector"
        )

    # The bounds must equal 0 / 255 element-wise; being merely contained in
    # [0, 255] is still flagged, because (per the message below) the CNN
    # policy normalizes observations automatically based on that range.
    if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
        warnings.warn(
            f"It seems that your observation space {key} is an image but the "
            "upper and lower bounds are not in [0, 255]. "
            "Because the CNN policy normalize automatically the observation "
            "you may encounter issue if the values are not in that range."
        )

    # Check only if width/height of the image is big enough.
    # Assumes a 3D image shape (NOTE(review): presumably guaranteed by the
    # caller). Axis 1 is spatial in both layouts; the other spatial axis is
    # 0 for channels-last (H, W, C) and -1 for channels-first (C, H, W).
    other_spatial_idx = -1 if is_image_space_channels_first(observation_space) else 0

    if observation_space.shape[other_spatial_idx] < 36 or observation_space.shape[1] < 36:
        warnings.warn(
            # Name the offending key and shape, like the other warnings do, so
            # users of Dict observation spaces can tell which entry is too small.
            f"It seems that your observation {key} is an image of shape {observation_space.shape}. "
            "The minimal resolution for an image is 36x36 for the default `CnnPolicy`. "
            "You might need to use a custom features extractor "
            "cf. https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html"
        )
| 93 | |
| 94 | |
| 95 | def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> bool: # noqa: C901 |
no test coverage detected
searching dependent graphs…