(algo)
| 10 | |
@pytest.mark.parametrize("algo", [A2C, DQN, PPO, SAC, TD3])
def test_deterministic_training_common(algo):
    """Check that seeded training and stochastic rollouts are reproducible.

    The algorithm is trained twice from scratch with the same ``SEED`` and
    hyperparameters. Each trained model is then rolled out for 100 steps with
    ``deterministic=False``, so the sampled actions depend on the seeded RNG.
    Both runs must yield exactly the same sequence of actions and rewards.

    :param algo: the algorithm class under test (parametrized).
    """
    results = [[], []]
    rewards = [[], []]
    # Smaller network
    kwargs = {"policy_kwargs": dict(net_arch=[64])}
    env_id = "Pendulum-v1"
    if algo in [TD3, SAC]:
        # Also exercise the (seeded) exploration noise used during training.
        kwargs.update(
            {"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)), "learning_starts": 100, "train_freq": 4}
        )
    elif algo == DQN:
        # DQN is run on a discrete-action environment instead of Pendulum.
        env_id = "CartPole-v1"
        kwargs.update({"learning_starts": 100, "target_update_interval": 100})
    elif algo == PPO:
        kwargs.update({"n_steps": 64, "n_epochs": 4})

    for i in range(2):
        model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)
        model.learn(N_STEPS_TRAINING)
        env = model.get_env()
        obs = env.reset()
        for _ in range(100):
            # Stochastic predictions: reproducibility must come from seeding.
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, _, _ = env.step(action)
            results[i].append(action)
            rewards[i].append(reward)
    # Compare the full trajectories element-wise. Comparing only their sums
    # would accept two different sequences that happen to total the same.
    np.testing.assert_array_equal(results[0], results[1], err_msg="actions differ between seeded runs")
    np.testing.assert_array_equal(rewards[0], rewards[1], err_msg="rewards differ between seeded runs")
nothing calls this directly
no test coverage detected
searching dependent graphs…