(model_class)
| 398 | |
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_offpolicy_normalization(model_class):
    """Check that off-policy models track the ``VecNormalize`` wrapper across ``set_env``.

    A training env (observations and rewards normalized) and an evaluation env
    (``training=False``, observations normalized, rewards not normalized) are each
    wrapped in ``VecNormalize``. ``get_vec_normalize_env()`` must return exactly the
    wrapper of whichever env is currently set on the model, and ``learn`` must run
    with either env attached.
    """
    env = DummyVecEnv([make_env])
    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0, clip_reward=10.0)

    eval_env = DummyVecEnv([make_env])
    eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=False, clip_obs=10.0, clip_reward=10.0)

    model = model_class("MlpPolicy", env, verbose=1, learning_starts=100, policy_kwargs=dict(net_arch=[64]))

    # Check that VecNormalize object is correctly updated
    assert model.get_vec_normalize_env() is env
    model.set_env(eval_env)
    assert model.get_vec_normalize_env() is eval_env
    # 10 steps is below learning_starts=100, so presumably this run only collects
    # transitions through the eval wrapper (no gradient updates) -- NOTE(review): confirm.
    model.learn(total_timesteps=10)
    model.set_env(env)
    model.learn(total_timesteps=150)
    # Check getter: after switching back and training, the model must point at the
    # training wrapper itself. An isinstance check alone would also pass with a stale
    # reference to eval_env, which is exactly the regression the checks above target.
    vec_normalize_env = model.get_vec_normalize_env()
    assert isinstance(vec_normalize_env, VecNormalize)
    assert vec_normalize_env is env
| 418 | |
| 419 | |
| 420 | @pytest.mark.parametrize("make_env", [make_env, make_dict_env]) |
nothing calls this directly
no test coverage detected
searching dependent graphs…