| 116 | |
@pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")])
def test_train_freq(tmp_path, train_freq):
    """Train SAC with each ``train_freq`` form, then resume from a checkpoint.

    The parametrization covers a bare integer, an explicit ``(n, "step")``
    tuple and an ``(n, "episode")`` tuple. For each value the test trains a
    model, saves it under ``tmp_path`` and reloads it twice -- once without
    passing ``train_freq`` to ``load`` and once passing it explicitly -- and
    runs ``learn`` again after every reload. The test has no explicit
    assertions; it passes as long as none of these calls raises.
    """
    checkpoint = tmp_path / "test_save.zip"

    sac_options = {
        "policy_kwargs": {"net_arch": [64, 64], "n_critics": 1},
        "learning_starts": 100,
        "buffer_size": 10000,
        "verbose": 1,
        "train_freq": train_freq,
    }
    model = SAC("MlpPolicy", "Pendulum-v1", **sac_options)

    # Initial training run, then write the checkpoint to disk.
    model.learn(total_timesteps=150)
    model.save(checkpoint)

    # Keep the environment of the trained model so both reloads reuse it.
    env = model.get_env()

    # Reload without passing train_freq and continue training.
    model = SAC.load(checkpoint, env=env)
    model.learn(total_timesteps=150)

    # Reload again, this time handing train_freq to load() explicitly.
    model = SAC.load(checkpoint, train_freq=train_freq, env=env)
    model.learn(total_timesteps=150)
| 135 | |
| 136 | |
| 137 | @pytest.mark.parametrize("train_freq", ["4", ("1", "episode"), "non_sense", (1, "close")]) |