| 96 | return obs, extras |
| 97 | |
def step(self, action: torch.Tensor) -> VecEnvStepReturn:
    """Step the wrapped multi-agent environment with one flat, single-agent action tensor.

    The columns of ``action`` are cut into consecutive per-agent chunks, following the
    order of ``self.env.possible_agents``. Each chunk is as wide as the flattened
    dimension (``gym.spaces.flatdim``) of that agent's action space. The per-agent
    outputs of the wrapped environment are then merged into single-agent data.

    Args:
        action: Batched actions, sliced as ``action[:, start:stop]`` per agent.
            NOTE(review): presumably shaped (num_envs, sum of flattened action
            dims) -- confirm against the caller.

    Returns:
        The tuple ``(obs, rewards, terminated, time_outs, extras)`` where

        * ``obs`` is a dict with the single key ``"policy"``. It holds either
          ``self.env.state()`` (when ``self._state_as_observation`` is truthy) or
          the per-agent observations, each reshaped to ``(self.num_envs, -1)`` and
          concatenated along the last dimension;
        * ``rewards`` is the sum of the per-agent rewards;
        * ``terminated`` and ``time_outs`` are the product of the per-agent flags
          cast to ``torch.bool``, i.e. non-zero only where every agent's flag is
          non-zero;
        * ``extras`` is passed through unchanged from the wrapped environment.
    """
    # carve the flat action tensor into one column range per agent
    # FIXME: This implementation assumes the spaces are fundamental ones. Fix it to support composite spaces
    agent_actions = {}
    start = 0
    for name in self.env.possible_agents:
        stop = start + gym.spaces.flatdim(self.env.action_spaces[name])
        agent_actions[name] = action[:, start:stop]
        start = stop

    # advance the wrapped multi-agent environment
    agent_obs, agent_rewards, agent_terminated, agent_time_outs, extras = self.env.step(agent_actions)

    if self._state_as_observation:
        # the environment state replaces the per-agent observations
        policy_obs = self.env.state()
    else:
        # flatten each agent's observation per environment and join them side by side
        # FIXME: This implementation assumes the spaces are fundamental ones. Fix it to support composite spaces
        flattened = [agent_obs[name].reshape(self.num_envs, -1) for name in self.env.possible_agents]
        policy_obs = torch.cat(flattened, dim=-1)

    # collapse the per-agent outputs into single-agent data
    rewards = sum(agent_rewards.values())
    terminated = math.prod(agent_terminated.values()).to(dtype=torch.bool)
    time_outs = math.prod(agent_time_outs.values()).to(dtype=torch.bool)

    return {"policy": policy_obs}, rewards, terminated, time_outs, extras
| 129 | |
| 130 | def render(self, recompute: bool = False) -> np.ndarray | None: |
| 131 | return self.env.render(recompute) |