(train_freq)
| 136 | |
@pytest.mark.parametrize("train_freq", ["4", ("1", "episode"), "non_sense", (1, "close")])
def test_train_freq_fail(train_freq):
    """Expect a ``ValueError`` when SAC is built and trained with a malformed ``train_freq``.

    :param train_freq: One of the invalid values supplied by the parametrization
        (strings where numbers are expected, or an unrecognised unit).
    """
    # Plain data only: assembling the arguments here cannot itself raise, so the
    # ``pytest.raises`` block below still wraps exactly the two SAC calls.
    sac_kwargs = {
        "policy_kwargs": {"net_arch": [64, 64], "n_critics": 1},
        "learning_starts": 100,
        "buffer_size": 10000,
        "verbose": 1,
        "train_freq": train_freq,
    }

    # Both construction and training stay inside the context manager, so the
    # error is accepted from either call.
    with pytest.raises(ValueError):
        agent = SAC("MlpPolicy", "Pendulum-v1", **sac_kwargs)
        agent.learn(total_timesteps=250)
| 150 | |
| 151 | |
| 152 | @pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN]) |