```python
import numpy as np
from gymnasium import spaces

from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper


class VecNormalize(VecEnvWrapper):
    """
    A moving average, normalizing wrapper for vectorized environments.
    Has support for saving/loading moving averages.

    :param venv: the vectorized environment to wrap
    :param training: Whether to update the moving average or not
    :param norm_obs: Whether to normalize observation or not (default: True)
    :param norm_reward: Whether to normalize rewards or not (default: True)
    :param clip_obs: Max absolute value for observation
    :param clip_reward: Max absolute value for discounted reward
    :param gamma: discount factor
    :param epsilon: To avoid division by zero
    :param norm_obs_keys: Which keys from observation dict to normalize.
        If not specified, all keys will be normalized.
    """

    obs_spaces: dict[str, spaces.Space]
    old_obs: np.ndarray | dict[str, np.ndarray]

    def __init__(
        self,
        venv: VecEnv,
        training: bool = True,
        norm_obs: bool = True,
        norm_reward: bool = True,
        clip_obs: float = 10.0,
        clip_reward: float = 10.0,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
        norm_obs_keys: list[str] | None = None,
    ):
        VecEnvWrapper.__init__(self, venv)

        self.norm_obs = norm_obs
        self.norm_obs_keys = norm_obs_keys
        # Check observation spaces
        if self.norm_obs:
            # Note: mypy doesn't take into account the sanity checks, which leads to several type: ignore...
            self._sanity_checks()

            if isinstance(self.observation_space, spaces.Dict):
                self.obs_spaces = self.observation_space.spaces
                self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys}  # type: ignore[arg-type, union-attr]
                # Update observation space when using image
                # See explanation below and GH #1214
                for key in self.obs_rms.keys():
                    if is_image_space(self.obs_spaces[key]):
                        self.observation_space.spaces[key] = spaces.Box(
                            low=-clip_obs,
                            high=clip_obs,
                            shape=self.obs_spaces[key].shape,
                            dtype=np.float32,
                        )

            else:
                self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)  # type: ignore[assignment, arg-type]
                # Update observation space when using image
```
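The constructor keeps a `RunningMeanStd` for the observations (one per key for dict spaces) and later standardizes observations with it, clipping to `±clip_obs`. The sketch below is a standalone approximation of that idea, not the stable-baselines3 implementation: `RunningMeanStdSketch` and `normalize_obs` are hypothetical names, and the update step assumes the standard parallel mean/variance merge applied to each batch of observations coming from the sub-environments.

```python
from __future__ import annotations

import numpy as np


class RunningMeanStdSketch:
    """Hypothetical stand-in for a running mean/variance tracker."""

    def __init__(self, shape: tuple[int, ...] = (), epsilon: float = 1e-4):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        # Small pseudo-count so the initial mean=0 / var=1 prior carries almost no weight
        self.count = epsilon

    def update(self, batch: np.ndarray) -> None:
        # Statistics of the new batch along the leading (environment) axis
        batch_mean = batch.mean(axis=0)
        batch_var = batch.var(axis=0)
        batch_count = batch.shape[0]

        # Merge old and new moments (parallel mean/variance algorithm)
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m2 = self.var * self.count + batch_var * batch_count + delta**2 * self.count * batch_count / total
        self.mean = new_mean
        self.var = m2 / total
        self.count = total


def normalize_obs(
    obs: np.ndarray,
    rms: RunningMeanStdSketch,
    clip_obs: float = 10.0,
    epsilon: float = 1e-8,
) -> np.ndarray:
    # Standardize with the running statistics, then clamp to [-clip_obs, clip_obs]
    return np.clip((obs - rms.mean) / np.sqrt(rms.var + epsilon), -clip_obs, clip_obs)
```

After feeding many batches drawn from a distribution centered at 5.0, normalizing an observation equal to 5.0 should give values near 0, while anything more than `clip_obs` standard deviations away is clamped. Because the merge is exact, the tracked mean and variance match those of all observations seen so far, up to the tiny prior pseudo-count.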
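For `spaces.Dict` observations, the constructor builds one running-statistics object per key in `norm_obs_keys`, and for image keys it replaces the declared space with a float32 `Box` bounded by `±clip_obs`, since normalized pixels are no longer `uint8` values in `[0, 255]`. A hedged sketch of applying such per-key statistics, assuming each key's mean and variance have already been estimated: `normalize_dict_obs` and its `stats` argument are hypothetical, and falling back to every key when `norm_obs_keys` is `None` follows the docstring rather than library code.

```python
from __future__ import annotations

import numpy as np


def normalize_dict_obs(
    obs: dict[str, np.ndarray],
    stats: dict[str, tuple[np.ndarray, np.ndarray]],
    norm_obs_keys: list[str] | None = None,
    clip_obs: float = 10.0,
    epsilon: float = 1e-8,
) -> dict[str, np.ndarray]:
    """Hypothetical helper: normalize only the selected keys of a dict observation."""
    # Per the docstring, no key list means every key is normalized
    keys = list(obs.keys()) if norm_obs_keys is None else norm_obs_keys
    # Shallow copy: keys that are not selected pass through unchanged
    out = dict(obs)
    for key in keys:
        mean, var = stats[key]
        normalized = (obs[key] - mean) / np.sqrt(var + epsilon)
        # Result is float32 in [-clip_obs, clip_obs], even for uint8 image inputs,
        # which is why an image key's space must become a float32 Box
        out[key] = np.clip(normalized, -clip_obs, clip_obs).astype(np.float32)
    return out
```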
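The `gamma`, `clip_reward` and `epsilon` parameters fit a common reward-normalization scheme: rewards are scaled by the spread of a running discounted return rather than of the raw reward. A minimal sketch under stated assumptions: `RewardScalerSketch` is hypothetical, rewards are divided by the standard deviation of the discounted returns without subtracting a mean, each sub-environment's return is reset when its episode ends, and for brevity the variance is recomputed from the full history instead of with an incremental update.

```python
from __future__ import annotations

import numpy as np


class RewardScalerSketch:
    """Hypothetical helper: scale rewards by the std of the running discounted return."""

    def __init__(self, num_envs: int, gamma: float = 0.99, clip_reward: float = 10.0, epsilon: float = 1e-8):
        self.gamma = gamma
        self.clip_reward = clip_reward
        self.epsilon = epsilon
        # One discounted-return accumulator per sub-environment
        self.returns = np.zeros(num_envs, dtype=np.float64)
        # Simplification: keep every observed return instead of a running mean/var
        self.history: list[np.ndarray] = []

    def __call__(self, rewards: np.ndarray, dones: np.ndarray) -> np.ndarray:
        # Discounted return: R_t = gamma * R_{t-1} + r_t
        self.returns = self.returns * self.gamma + rewards
        self.history.append(self.returns.copy())
        var = np.var(np.concatenate(self.history))
        # Scale only (no mean subtraction), then clip to [-clip_reward, clip_reward]
        scaled = np.clip(rewards / np.sqrt(var + self.epsilon), -self.clip_reward, self.clip_reward)
        # Sub-environments whose episode just ended start a fresh return
        self.returns[dones] = 0.0
        return scaled
```

On the first call the history holds identical values, so the variance is 0 and the scaled reward saturates at `clip_reward`; starting from a prior variance, as the running tracker above does, avoids that early spike.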