Step env helper.
step_env(env, max_path_length=100, iterations=1, render=True)
| 2 | |
| 3 | |
def _assert_obs_equal(actual, expected, label):
    """Raise ``AssertionError`` unless ``actual`` and ``expected`` are equal.

    Used instead of a bare ``assert`` so the check still runs under
    ``python -O`` and a failure names the observation slice that disagreed,
    together with both values.

    ``np.array_equal`` reports a shape mismatch as inequality rather than
    broadcasting, and with its default ``equal_nan=False`` a NaN never
    compares equal, so a NaN anywhere fails the check.
    """
    if not np.array_equal(actual, expected):
        raise AssertionError(
            f"{label} mismatch:\n  observed: {actual}\n  expected: {expected}"
        )


def step_env(env, max_path_length=100, iterations=1, render=True):
    """Step ``env`` with random actions and check its observation layout.

    For each of ``iterations`` episodes the env is reset, then stepped with
    ``env.action_space.sample()`` for up to ``max_path_length`` steps. An
    episode stops early when a step reports ``terminated`` or ``truncated``.
    After every step the new observation is compared slice by slice with the
    env's internal accessors:

    * ``[:3]``    -- ``env.get_endeff_pos()``
    * ``[4:7]``   -- ``env._get_pos_objects()[:3]``
    * ``[7:11]``  -- ``env._get_quat_objects()[:4]``
    * ``[11:14]`` -- ``_get_pos_objects()[3:]`` when that array has shape
      ``(6,)``, otherwise zeros
    * ``[14:18]`` -- ``_get_quat_objects()[4:]`` (which must then have shape
      ``(8,)``) in the same case, otherwise zeros
    * ``[18:-3]`` -- the previous observation's first 18 entries
    * ``[-3:]``   -- ``env._get_pos_goal()``, or zeros when
      ``env._partially_observable`` is truthy

    Index 3 is not checked here (presumably gripper state -- TODO confirm).

    Args:
        env: Environment under test. ``env.reset()`` must return an indexable
            whose item 0 is the observation. ``env.step(action)`` must return
            five values; items 0, 2 and 3 are used as the next observation,
            ``terminated`` and ``truncated``. The env must also expose
            ``action_space.sample()``, ``render()``, ``_partially_observable``
            and the accessors listed above.
        max_path_length: Maximum number of steps per episode.
        iterations: Number of episodes (resets) to run.
        render: If true, call ``env.render()`` after every step.

    Raises:
        AssertionError: If any checked observation slice disagrees with the
            env's accessors.
    """
    for _episode in range(iterations):
        obs = env.reset()[0]
        for _step in range(max_path_length):
            next_obs, _unused, terminated, truncated, _info = env.step(
                env.action_space.sample()
            )

            # Goal slot: zero-filled when the env is flagged partially observable.
            if env._partially_observable:
                expected_goal = np.zeros(3)
            else:
                expected_goal = env._get_pos_goal()
            _assert_obs_equal(next_obs[-3:], expected_goal, "goal obs[-3:]")

            _assert_obs_equal(
                next_obs[:3], env.get_endeff_pos(), "end-effector obs[:3]"
            )

            pos_objects = env._get_pos_objects()
            quat_objects = env._get_quat_objects()
            _assert_obs_equal(
                next_obs[4:7], pos_objects[:3], "object 1 position obs[4:7]"
            )
            _assert_obs_equal(
                next_obs[7:11], quat_objects[:4], "object 1 quaternion obs[7:11]"
            )

            if pos_objects.shape == (6,):
                # Six position values: the second slot must mirror the tail
                # of both accessor arrays.
                if quat_objects.shape != (8,):
                    raise AssertionError(
                        "expected _get_quat_objects() shape (8,) alongside "
                        f"6 position values, got {quat_objects.shape}"
                    )
                second_pos, second_quat = pos_objects[3:], quat_objects[4:]
            else:
                # Otherwise the second slot must be zero padding.
                second_pos, second_quat = np.zeros(3), np.zeros(4)
            _assert_obs_equal(
                next_obs[11:14], second_pos, "object 2 position obs[11:14]"
            )
            _assert_obs_equal(
                next_obs[14:18], second_quat, "object 2 quaternion obs[14:18]"
            )

            # Entries between the current frame and the goal repeat the
            # previous step's first 18 entries.
            _assert_obs_equal(
                next_obs[18:-3], obs[:18], "previous frame obs[18:-3]"
            )

            obs = next_obs
            if render:
                env.render()
            if terminated or truncated:
                break
searching dependent graphs…