Step the environment. Args: action: The action to take. Must be a 4-element array of floats. Returns: The (next_obs, reward, terminated, truncated, info) tuple.
(
self, action: npt.NDArray[np.float32]
)
| 578 | |
| 579 | @_Decorators.assert_task_is_set |
| 580 | def step( |
| 581 | self, action: npt.NDArray[np.float32] |
| 582 | ) -> tuple[npt.NDArray[np.float64], SupportsFloat, bool, bool, dict[str, Any]]: |
| 583 | """Step the environment. |
| 584 | |
| 585 | Args: |
| 586 | action: The action to take. Must be a 4 element array of floats. |
| 587 | |
| 588 | Returns: |
| 589 | The (next_obs, reward, terminated, truncated, info) tuple. |
| 590 | """ |
| 591 | assert len(action) == 4, f"Actions should be size 4, got {len(action)}" |
| 592 | self.set_xyz_action(action[:3]) |
| 593 | if self.curr_path_length >= self.max_path_length: |
| 594 | raise ValueError("You must reset the env manually once truncate==True") |
| 595 | self.do_simulation([action[-1], -action[-1]], n_frames=self.frame_skip) |
| 596 | self.curr_path_length += 1 |
| 597 | |
| 598 | # Running the simulator can sometimes mess up site positions, so |
| 599 | # re-position them here to make sure they're accurate |
| 600 | for site in self._target_site_config: |
| 601 | self._set_pos_site(*site) |
| 602 | |
| 603 | if self._did_see_sim_exception: |
| 604 | assert self._last_stable_obs is not None |
| 605 | return ( |
| 606 | self._last_stable_obs, # observation just before going unstable |
| 607 | 0.0, # reward (penalize for causing instability) |
| 608 | False, |
| 609 | False, # termination flag always False |
| 610 | { # info |
| 611 | "success": False, |
| 612 | "near_object": 0.0, |
| 613 | "grasp_success": False, |
| 614 | "grasp_reward": 0.0, |
| 615 | "in_place_reward": 0.0, |
| 616 | "obj_to_target": 0.0, |
| 617 | "unscaled_reward": 0.0, |
| 618 | }, |
| 619 | ) |
| 620 | mujoco.mj_forward(self.model, self.data) |
| 621 | self._last_stable_obs = self._get_obs() |
| 622 | |
| 623 | self._last_stable_obs = np.clip( |
| 624 | self._last_stable_obs, |
| 625 | a_max=self.sawyer_observation_space.high, |
| 626 | a_min=self.sawyer_observation_space.low, |
| 627 | dtype=np.float64, |
| 628 | ) |
| 629 | assert isinstance(self._last_stable_obs, np.ndarray) |
| 630 | reward, info = self.evaluate_state(self._last_stable_obs, action) |
| 631 | # step will never return a terminate==True if there is a success |
| 632 | # but we can return truncate=True if the current path length == max path length |
| 633 | truncate = False |
| 634 | if self.curr_path_length == self.max_path_length: |
| 635 | truncate = True |
| 636 | return ( |
| 637 | np.array(self._last_stable_obs, dtype=np.float64), |
no test coverage detected