hub / github.com/DLR-RM/stable-baselines3 / test_replay_buffer_normalization

Function test_replay_buffer_normalization

tests/test_buffers.py:82–109
(replay_buffer_cls)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer])
def test_replay_buffer_normalization(replay_buffer_cls):
    env = {ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv}[replay_buffer_cls]
    env = make_vec_env(env)
    env = VecNormalize(env)

    buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu")

    # Interact and store transitions
    env.reset()
    obs = env.get_original_obs()
    for _ in range(100):
        action = env.action_space.sample()
        _, _, done, info = env.step(action)
        next_obs = env.get_original_obs()
        reward = env.get_original_reward()
        buffer.add(obs, next_obs, action, reward, done, info)
        obs = next_obs

    sample = buffer.sample(50, env)
    # Test observation normalization
    for observations in [sample.observations, sample.next_observations]:
        if isinstance(sample, DictReplayBufferSamples):
            for key in observations.keys():
                assert th.allclose(observations[key].mean(0), th.zeros(1), atol=1)
        elif isinstance(sample, ReplayBufferSamples):
            assert th.allclose(observations.mean(0), th.zeros(1), atol=1)
    # Test reward normalization
    assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)
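The test above stores unnormalized observations and rewards in the buffer, then expects the sampled batch to come back normalized, with a mean close to zero. The sketch below illustrates that idea with the standard library only. It is not the stable-baselines3 implementation: `RunningStats` and `collect_and_sample` are hypothetical names, and the Gaussian observation source is an assumption made purely for illustration.

```python
import math
import random


class RunningStats:
    """Welford-style running mean and variance (hypothetical helper)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    @property
    def var(self):
        # Population variance; fall back to 1.0 before any data is seen.
        return self.m2 / self.count if self.count else 1.0

    def normalize(self, x, eps=1e-8):
        return (x - self.mean) / math.sqrt(self.var + eps)


def collect_and_sample(n_steps=100, batch_size=50, seed=0):
    rng = random.Random(seed)
    stats = RunningStats()
    raw_buffer = []
    for _ in range(n_steps):
        # Assumed observation source: values far from zero mean.
        obs = rng.gauss(5.0, 2.0)
        stats.update(obs)        # statistics are updated as data arrives
        raw_buffer.append(obs)   # the buffer keeps the original value
    # Normalization is applied only when a batch is drawn.
    batch = rng.sample(raw_buffer, batch_size)
    return [stats.normalize(x) for x in batch]


normalized = collect_and_sample()
batch_mean = sum(normalized) / len(normalized)
# Same loose tolerance as the test's atol=1.
assert abs(batch_mean) < 1.0
```

Keeping raw values in storage and normalizing at sample time means a batch is always scaled with the latest statistics, not with whatever the statistics were when each transition was recorded.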

Callers

nothing calls this directly

Calls 8

reset · Method · 0.95
get_original_obs · Method · 0.95
get_original_reward · Method · 0.95
make_vec_env · Function · 0.90
VecNormalize · Class · 0.90
sample · Method · 0.45
step · Method · 0.45
add · Method · 0.45

Tested by

no test coverage detected
