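Frame stacking wrapper for vectorized environments, designed for image observations. :param venv: vectorized environment to wrap :param n_stack: number of frames to stack :param channels_order: if "first", stack on the first image dimension; if "last", stack on the last dimension. If None, automatically detect the channel to stack over for image observations, or default to "last" (default). Alternatively, channels_order can be a dictionary, for use with environments that have Dict observation spaces.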
| 9 | |
| 10 | |
class VecFrameStack(VecEnvWrapper):
    """
    Vectorized-environment wrapper that stacks the most recent ``n_stack`` observations.
    Intended for image observations.

    :param venv: The vectorized environment being wrapped.
    :param n_stack: How many consecutive frames to stack together.
    :param channels_order: ``"first"`` stacks along the first image dimension and
        ``"last"`` along the last one. With ``None`` (the default), the channel to stack
        over is detected automatically for image observations, falling back to ``"last"``.
        A mapping may be passed instead, for environments with ``Dict`` observation spaces.
    """

    def __init__(self, venv: VecEnv, n_stack: int, channels_order: str | Mapping[str, str] | None = None) -> None:
        obs_space = venv.observation_space
        # Only Box and Dict observation spaces are supported.
        assert isinstance(
            obs_space, (spaces.Box, spaces.Dict)
        ), "VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces"

        # Build the stacker first: the base-class constructor is handed the stacked
        # observation space that the stacker exposes.
        self.stacked_obs = StackedObservations(venv.num_envs, n_stack, obs_space, channels_order)
        super().__init__(venv, observation_space=self.stacked_obs.stacked_observation_space)

    def step_wait(
        self,
    ) -> tuple[
        np.ndarray | dict[str, np.ndarray],
        np.ndarray,
        np.ndarray,
        list[dict[str, Any]],
    ]:
        """
        Wait for the wrapped environment's step and return its results with stacked observations.

        The raw observations, ``dones`` and ``infos`` are handed to the stacker. The stacker
        returns the stacked observations together with the ``infos`` that are passed on.

        :return: (stacked observations, rewards, dones, infos)
        """
        raw_obs, rewards, dones, infos = self.venv.step_wait()
        stacked, infos = self.stacked_obs.update(raw_obs, dones, infos)  # type: ignore[arg-type]
        return stacked, rewards, dones, infos

    def reset(self) -> np.ndarray | dict[str, np.ndarray]:
        """
        Reset all environments and return the stacked initial observation.
        """
        return self.stacked_obs.reset(self.venv.reset())  # type: ignore[arg-type]
no outgoing calls
searching dependent graphs…