Repository: github.com/DLR-RM/stable-baselines3

Function test_deterministic_training_common

tests/test_deterministic.py:12–40
(algo)

import numpy as np
import pytest

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.noise import NormalActionNoise

# Module-level constants defined above this function in the original file
# (values assumed here so the snippet runs on its own).
N_STEPS_TRAINING = 300
SEED = 0


@pytest.mark.parametrize("algo", [A2C, DQN, PPO, SAC, TD3])
def test_deterministic_training_common(algo):
    results = [[], []]
    rewards = [[], []]
    # Smaller network
    kwargs = {"policy_kwargs": dict(net_arch=[64])}
    env_id = "Pendulum-v1"
    if algo in [TD3, SAC]:
        kwargs.update(
            {"action_noise": NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)), "learning_starts": 100, "train_freq": 4}
        )
    else:
        if algo == DQN:
            env_id = "CartPole-v1"
            kwargs.update({"learning_starts": 100, "target_update_interval": 100})
        elif algo == PPO:
            kwargs.update({"n_steps": 64, "n_epochs": 4})

    # Train two models with the same seed, then roll each one out for 100 steps.
    for i in range(2):
        model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)
        model.learn(N_STEPS_TRAINING)
        env = model.get_env()
        obs = env.reset()
        for _ in range(100):
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, _, _ = env.step(action)
            results[i].append(action)
            rewards[i].append(reward)
    # Identical seeds must yield identical (stochastic) actions and rewards.
    assert sum(results[0]) == sum(results[1]), results
    assert sum(rewards[0]) == sum(rewards[1]), rewards
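The test's core invariant is that two runs with the same seed produce identical stochastic actions and rewards, even with `deterministic=False`. The sketch that follows illustrates that invariant without Stable-Baselines3: a hypothetical `rollout()` helper replaces training and prediction with draws from a single seeded NumPy generator (`numpy.random.default_rng`). It is an analogy for what `seed=SEED` is meant to guarantee, not a model of SB3's internals.

```python
import numpy as np


def rollout(seed, n_steps=100, action_dim=1):
    """Hypothetical stand-in for one train-then-predict run.

    A Gaussian "policy" samples stochastic actions (the analogue of
    predict(..., deterministic=False)); the reward is a toy function of
    the action. All randomness flows from one seeded generator,
    mirroring seed=SEED in the test.
    """
    rng = np.random.default_rng(seed)
    mean = rng.normal(size=action_dim)  # stands in for "trained" parameters
    actions, rewards = [], []
    for _ in range(n_steps):
        action = rng.normal(loc=mean, scale=0.1)  # stochastic action, shape (action_dim,)
        reward = -float(np.sum(action ** 2))      # toy reward
        actions.append(action)
        rewards.append(reward)
    return actions, rewards


# Same seed twice: the sums match exactly, as in the test's assertions.
a0, r0 = rollout(seed=0)
a1, r1 = rollout(seed=0)
assert np.array_equal(sum(a0), sum(a1))
assert sum(r0) == sum(r1)

# A different seed should diverge.
a2, _ = rollout(seed=1)
print(np.array_equal(sum(a0), sum(a2)))
```

Comparing sums, as the test does, is a compact fingerprint that is exact here because both runs perform the same floating-point operations in the same order. It could in principle miss differences that cancel out, so comparing the stacked arrays with `np.array_equal` would be a stricter check.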

Callers

nothing calls this directly (pytest collects and runs it)

Calls (7)

NormalActionNoise (class) · 0.90
get_env (method) · 0.80
update (method) · 0.45
learn (method) · 0.45
reset (method) · 0.45
predict (method) · 0.45
step (method) · 0.45
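For TD3 and SAC the test passes `NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))` as `action_noise`: zero-mean Gaussian noise with standard deviation 0.1 on Pendulum-v1's single action dimension. The hypothetical `GaussianActionNoise` class that follows sketches that idea with its own seeded generator. It is an illustration, not SB3's implementation, and the clipping to Pendulum-v1's torque range of [-2, 2] is an assumption added for the example.

```python
import numpy as np


class GaussianActionNoise:
    """Hypothetical sketch of zero-mean Gaussian exploration noise.

    Takes the same (mean, sigma) arguments the test passes to
    NormalActionNoise; this is not SB3's code.
    """

    def __init__(self, mean, sigma, seed=None):
        self.mean = np.asarray(mean, dtype=np.float64)
        self.sigma = np.asarray(sigma, dtype=np.float64)
        self.rng = np.random.default_rng(seed)  # seeded, so draws are reproducible

    def __call__(self):
        # One independent N(mean, sigma^2) draw per action dimension.
        return self.rng.normal(self.mean, self.sigma)


noise = GaussianActionNoise(np.zeros(1), 0.1 * np.ones(1), seed=0)
action = np.array([0.5]) + noise()   # perturb a (made-up) policy action
action = np.clip(action, -2.0, 2.0)  # assumed: keep within Pendulum-v1 torque bounds
print(action)
```

Because the generator is seeded, two noise objects built with the same seed emit the same sequence, which is the property the determinism test depends on for the off-policy algorithms.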

Tested by

no test coverage detected
