(replay_buffer_cls)
| 80 | |
@pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer])
def test_replay_buffer_normalization(replay_buffer_cls):
    """Check that sampling a replay buffer with a ``VecNormalize`` env normalizes data.

    The buffer is filled with 100 transitions built from the values returned by
    ``env.get_original_obs()`` and ``env.get_original_reward()``. A batch of 50
    is then sampled with the env passed to ``sample``, and the mean over the
    first axis of the observations, next observations and rewards is asserted
    to be within an absolute tolerance of 1 from zero.

    :param replay_buffer_cls: Buffer class under test; it also selects the
        matching dummy environment (``DummyEnv`` or ``DummyDictEnv``).
    """
    env_cls_by_buffer = {ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv}
    env = VecNormalize(make_vec_env(env_cls_by_buffer[replay_buffer_cls]))

    buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu")

    # Roll the env forward and record each transition using the "original"
    # observation/reward accessors of the wrapper.
    env.reset()
    obs = env.get_original_obs()
    for _ in range(100):
        action = env.action_space.sample()
        _, _, done, info = env.step(action)
        next_obs = env.get_original_obs()
        reward = env.get_original_reward()
        buffer.add(obs, next_obs, action, reward, done, info)
        obs = next_obs

    sample = buffer.sample(50, env)
    zero = th.zeros(1)

    # Observation normalization: gather the tensors to check for the sample
    # type at hand, then apply the same closeness assertion to each of them.
    for observations in (sample.observations, sample.next_observations):
        if isinstance(sample, DictReplayBufferSamples):
            tensors = list(observations.values())
        elif isinstance(sample, ReplayBufferSamples):
            tensors = [observations]
        else:
            tensors = []
        for tensor in tensors:
            assert th.allclose(tensor.mean(0), zero, atol=1)

    # Reward normalization
    assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)
| 110 | |
| 111 | |
| 112 | @pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer]) |
nothing calls this directly
no test coverage detected
searching dependent graphs…