()
| 368 | |
| 369 | |
def test_her_normalization():
    """Check VecNormalize tracking for a SAC + HerReplayBuffer model.

    The model is built on a training ``VecNormalize`` env.  The test swaps
    in a second, non-training ``VecNormalize`` env via ``set_env``, trains
    briefly, then swaps the original env back and trains again.  After every
    swap, ``get_vec_normalize_env`` must return the exact wrapper currently
    attached to the model.
    """
    # NOTE(review): make_dict_env is assumed to build a dict-observation,
    # goal-conditioned env suitable for HerReplayBuffer -- confirm at its definition.
    env = DummyVecEnv([make_dict_env])
    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0, clip_reward=10.0)

    # Second wrapper with training=False and reward normalization disabled,
    # so it is a distinct VecNormalize object from ``env``.
    eval_env = DummyVecEnv([make_dict_env])
    eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=False, clip_obs=10.0, clip_reward=10.0)

    model = SAC(
        "MultiInputPolicy",
        env,
        verbose=1,
        learning_starts=100,
        policy_kwargs=dict(net_arch=[64]),
        replay_buffer_kwargs=dict(n_sampled_goal=2),
        replay_buffer_class=HerReplayBuffer,
        seed=2,
    )

    # Check that VecNormalize object is correctly updated
    assert model.get_vec_normalize_env() is env
    model.set_env(eval_env)
    assert model.get_vec_normalize_env() is eval_env
    # 10 steps < learning_starts, so this run only exercises collection on eval_env.
    model.learn(total_timesteps=10)
    model.set_env(env)
    # Identity check, not just a type check: eval_env is also a VecNormalize,
    # so an isinstance-only assertion would miss a stale reference here.
    assert model.get_vec_normalize_env() is env
    model.learn(total_timesteps=150)
    # Check getter after training on the restored env
    assert isinstance(model.get_vec_normalize_env(), VecNormalize)
    assert model.get_vec_normalize_env() is env
| 397 | |
| 398 | |
| 399 | @pytest.mark.parametrize("model_class", [SAC, TD3]) |
nothing calls this directly
no test coverage detected
searching dependent graphs…