(show_viewer)
| 841 | @pytest.mark.slow # ~200s |
| 842 | @pytest.mark.required |
| 843 | def test_reset(show_viewer): |
| 844 | BOOL_MASK = torch.tensor([True, False, True, False], dtype=torch.bool, device=gs.device) |
| 845 | |
| 846 | scene = gs.Scene( |
| 847 | show_viewer=show_viewer, |
| 848 | ) |
| 849 | scene.add_entity( |
| 850 | gs.morphs.URDF( |
| 851 | file="urdf/plane/plane.urdf", |
| 852 | fixed=True, |
| 853 | ) |
| 854 | ) |
| 855 | scene.add_entity( |
| 856 | gs.morphs.Box( |
| 857 | size=(0.1, 0.1, 0.1), |
| 858 | pos=(0, 0, 0.5), |
| 859 | ) |
| 860 | ) |
| 861 | scene.build(n_envs=4) |
| 862 | |
| 863 | init_state = scene.get_state() |
| 864 | init_rigid_state = next(s for s in init_state.solvers_state if isinstance(s, RigidSolverState)) |
| 865 | for _ in range(50): |
| 866 | scene.step() |
| 867 | fallen_state = scene.get_state() |
| 868 | fallen_rigid_state = next(s for s in fallen_state.solvers_state if isinstance(s, RigidSolverState)) |
| 869 | |
| 870 | for envs_idx in (BOOL_MASK, torch.where(BOOL_MASK)[0]): |
| 871 | scene.reset(state=fallen_state) |
| 872 | scene.reset(state=init_state, envs_idx=envs_idx) |
| 873 | for actual, init_ref, fallen_ref in ( |
| 874 | ( |
| 875 | qd_to_torch(scene.rigid_solver._rigid_global_info.qpos, transpose=True, copy=True), |
| 876 | init_rigid_state.qpos, |
| 877 | fallen_rigid_state.qpos, |
| 878 | ), |
| 879 | ( |
| 880 | qd_to_torch(scene.rigid_solver.dofs_state.vel, transpose=True, copy=True), |
| 881 | init_rigid_state.dofs_vel, |
| 882 | fallen_rigid_state.dofs_vel, |
| 883 | ), |
| 884 | ( |
| 885 | qd_to_torch(scene.rigid_solver.links_state.pos, transpose=True, copy=True), |
| 886 | init_rigid_state.links_pos, |
| 887 | fallen_rigid_state.links_pos, |
| 888 | ), |
| 889 | ): |
| 890 | assert_allclose(actual[BOOL_MASK], init_ref[BOOL_MASK], tol=gs.EPS) |
| 891 | assert_allclose(actual[~BOOL_MASK], fallen_ref[~BOOL_MASK], tol=gs.EPS) |
| 892 | |
| 893 | # After reset, simulation from init_state should reproduce the original fallen_state trajectory |
| 894 | for _ in range(50): |
| 895 | scene.step() |
| 896 | for actual, fallen_ref in ( |
| 897 | (qd_to_torch(scene.rigid_solver._rigid_global_info.qpos, transpose=True, copy=True), fallen_rigid_state.qpos), |
| 898 | (qd_to_torch(scene.rigid_solver.dofs_state.vel, transpose=True, copy=True), fallen_rigid_state.dofs_vel), |
| 899 | (qd_to_torch(scene.rigid_solver.links_state.pos, transpose=True, copy=True), fallen_rigid_state.links_pos), |
| 900 | ): |
nothing calls this directly
no test coverage detected