hub / github.com/DLR-RM/stable-baselines3 / VecNormalize

Class VecNormalize

stable_baselines3/common/vec_env/vec_normalize.py:15–332

A moving-average normalizing wrapper for vectorized environments, with support for saving and loading the moving-average statistics. Main parameters: venv is the vectorized environment to wrap; training controls whether the moving average is updated; norm_obs controls whether observations are normalized (default: True). The class docstring lists the remaining parameters.
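To make the idea of a moving-average normalizer concrete, here is a minimal pure-Python sketch of scalar running statistics plus a clipped normalization step. It is an illustration only, not the library's code: `RunningStats` and `normalize` are hypothetical names, and the exact formula is an assumption based on the documented `clip_obs` and `epsilon` parameters.

```python
import math


class RunningStats:
    """Scalar running mean and variance, updated one batch at a time.

    Hypothetical helper for illustration; not part of stable-baselines3.
    """

    def __init__(self) -> None:
        self.mean = 0.0
        self.var = 1.0
        self.count = 1e-4  # small prior count so the first update is well defined

    def update(self, batch: list[float]) -> None:
        n = len(batch)
        batch_mean = sum(batch) / n
        batch_var = sum((x - batch_mean) ** 2 for x in batch) / n
        delta = batch_mean - self.mean
        total = self.count + n
        # Merge the stored moments with the batch moments
        # (parallel variance formula).
        m2 = self.var * self.count + batch_var * n + delta**2 * self.count * n / total
        self.mean += delta * n / total
        self.var = m2 / total
        self.count = total


def normalize(value: float, stats: RunningStats, clip: float = 10.0, epsilon: float = 1e-8) -> float:
    """Standardize a value, then clip it to [-clip, clip].

    Assumption: epsilon is added under the square root to avoid division by zero.
    """
    scaled = (value - stats.mean) / math.sqrt(stats.var + epsilon)
    return max(-clip, min(clip, scaled))
```

In this sketch, `update` plays the role that `training=True` enables: calling it during data collection and skipping it at evaluation time keeps the statistics frozen while `normalize` continues to apply them.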

Source from the content-addressed store, hash-verified

class VecNormalize(VecEnvWrapper):
    """
    A moving average, normalizing wrapper for vectorized environment.
    has support for saving/loading moving average,

    :param venv: the vectorized environment to wrap
    :param training: Whether to update or not the moving average
    :param norm_obs: Whether to normalize observation or not (default: True)
    :param norm_reward: Whether to normalize rewards or not (default: True)
    :param clip_obs: Max absolute value for observation
    :param clip_reward: Max value absolute for discounted reward
    :param gamma: discount factor
    :param epsilon: To avoid division by zero
    :param norm_obs_keys: Which keys from observation dict to normalize.
        If not specified, all keys will be normalized.
    """

    obs_spaces: dict[str, spaces.Space]
    old_obs: np.ndarray | dict[str, np.ndarray]

    def __init__(
        self,
        venv: VecEnv,
        training: bool = True,
        norm_obs: bool = True,
        norm_reward: bool = True,
        clip_obs: float = 10.0,
        clip_reward: float = 10.0,
        gamma: float = 0.99,
        epsilon: float = 1e-8,
        norm_obs_keys: list[str] | None = None,
    ):
        VecEnvWrapper.__init__(self, venv)

        self.norm_obs = norm_obs
        self.norm_obs_keys = norm_obs_keys
        # Check observation spaces
        if self.norm_obs:
            # Note: mypy doesn't take into account the sanity checks, which lead to several type: ignore...
            self._sanity_checks()

            if isinstance(self.observation_space, spaces.Dict):
                self.obs_spaces = self.observation_space.spaces
                self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys}  # type: ignore[arg-type, union-attr]
                # Update observation space when using image
                # See explanation below and GH #1214
                for key in self.obs_rms.keys():
                    if is_image_space(self.obs_spaces[key]):
                        self.observation_space.spaces[key] = spaces.Box(
                            low=-clip_obs,
                            high=clip_obs,
                            shape=self.obs_spaces[key].shape,
                            dtype=np.float32,
                        )

            else:
                self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)  # type: ignore[assignment, arg-type]
                # Update observation space when using image
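The dictionary branch in the listing keeps one set of running statistics per key in `norm_obs_keys`, and the docstring says that all keys are normalized when none are specified. The sketch below mirrors that behavior in pure Python. The helper names `select_norm_keys` and `normalize_dict_obs` are hypothetical, the rejection of unknown keys is an assumption (the body of `_sanity_checks` is not shown in the excerpt), and statistics are simplified to a `(mean, var)` pair per key.

```python
from __future__ import annotations

import math


def select_norm_keys(obs_space_keys: list[str], norm_obs_keys: list[str] | None) -> list[str]:
    """Return the keys to normalize; None means every key in the space."""
    if norm_obs_keys is None:
        return list(obs_space_keys)
    # Assumption: asking for a key that is not in the space is an error.
    unknown = [key for key in norm_obs_keys if key not in obs_space_keys]
    if unknown:
        raise ValueError(f"Unknown observation keys: {unknown}")
    return list(norm_obs_keys)


def normalize_dict_obs(
    obs: dict[str, list[float]],
    stats: dict[str, tuple[float, float]],
    clip_obs: float = 10.0,
    epsilon: float = 1e-8,
) -> dict[str, list[float]]:
    """Normalize only the keys that have statistics; pass other keys through."""
    out = dict(obs)
    for key, (mean, var) in stats.items():
        scale = math.sqrt(var + epsilon)
        out[key] = [max(-clip_obs, min(clip_obs, (x - mean) / scale)) for x in obs[key]]
    return out
```

Keys without statistics are returned unchanged, which is one plausible way to handle entries such as discrete flags that should not be rescaled.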

Callers 15

test_vec_deterministic · Function · 0.90
test_render · Function · 0.90
test_video_recorder · Function · 0.90
test_vec_normalize · Function · 0.90
test_vec_normalize_image · Function · 0.90
test_vec_monitor_warn · Function · 0.90
_make_warmstart · Function · 0.90
test_vec_env · Function · 0.90
test_her_normalization · Function · 0.90
test_sync_vec_normalize · Function · 0.90

Calls

no outgoing calls

Tested by 15

test_vec_deterministic · Function · 0.72
test_render · Function · 0.72
test_video_recorder · Function · 0.72
test_vec_normalize · Function · 0.72
test_vec_normalize_image · Function · 0.72
test_vec_monitor_warn · Function · 0.72
_make_warmstart · Function · 0.72
test_vec_env · Function · 0.72
test_her_normalization · Function · 0.72
test_sync_vec_normalize · Function · 0.72
