| 122 | self.state[:, :, :, -self.nc:] = obs |
| 123 | |
def run(self):
    """Collect one rollout of ``self.nsteps`` steps from the vectorised env.

    At every step the agent chooses actions and value estimates for the
    current stacked state, and the env is stepped. Finished episodes are
    logged to ``self.total_rewards`` / ``self.real_total_rewards``. The
    frame stack of every env whose episode ended is zeroed before the new
    observation is pushed in with ``self.update_state``. Afterwards each
    env's rewards are converted into discounted returns with
    ``discount_with_dones``, bootstrapping from ``self.agent.value`` when
    that env's last step was not terminal.

    Returns:
        tuple: ``(mb_states, mb_rewards, mb_actions, mb_values)``.
        ``mb_states`` is reshaped to ``self.batch_ob_shape``. The other
        three are flattened env-major: all steps of env 0, then env 1, ...
    """
    mb_states, mb_rewards, mb_actions, mb_values, mb_dones = [], [], [], [], []
    for _ in range(self.nsteps):
        actions, values = self.agent.step(self.state)
        mb_states.append(np.copy(self.state))
        mb_actions.append(actions)
        mb_values.append(values)
        # Done flags *before* this step. The first column is dropped after
        # the loop, so that each reward lines up with the done flag
        # produced by the same env step.
        mb_dones.append(self.dones)
        obs, rewards, dones, infos = self.env.step(actions)
        for done, info in zip(dones, infos):
            if done:
                self.total_rewards.append(info['reward'])
                # NOTE(review): nesting reconstructed from a flattened paste.
                # -1 looks like a "no real episode total yet" sentinel; confirm.
                if info['total_reward'] != -1:
                    self.real_total_rewards.append(info['total_reward'])
        self.dones = dones
        for env_idx, done in enumerate(dones):
            if done:
                # Wipe the stacked frames so the new episode does not start
                # with frames left over from the previous one.
                self.state[env_idx] = 0
        self.update_state(obs)
        mb_rewards.append(rewards)
    mb_dones.append(self.dones)

    # Batch of steps -> batch of rollouts: swap to (nenv, nsteps, ...).
    mb_states = np.asarray(mb_states, dtype=np.uint8).swapaxes(1, 0).reshape(self.batch_ob_shape)
    mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
    mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0)
    mb_values = np.asarray(mb_values, dtype=np.float32).swapaxes(1, 0)
    # The builtin ``bool`` is used because the ``np.bool`` alias was removed
    # in NumPy 1.24.
    mb_dones = np.asarray(mb_dones, dtype=bool).swapaxes(1, 0)
    # Drop the initial flags: column t now holds the done flag returned by step t.
    mb_dones = mb_dones[:, 1:]
    last_values = self.agent.value(self.state).tolist()

    # Discount rewards per env, bootstrapping off the value function.
    for env_idx, (rewards, dones, value) in enumerate(
            zip(mb_rewards, mb_dones, last_values)):
        rewards = rewards.tolist()
        dones = dones.tolist()
        if not dones[-1]:
            # The last episode is still running. Append V(s_T) as an extra
            # reward with a done flag of 0, discount, then drop that extra
            # entry from the result.
            rewards = discount_with_dones(rewards + [value], dones + [0], self.gamma)[:-1]
        else:
            rewards = discount_with_dones(rewards, dones, self.gamma)
        mb_rewards[env_idx] = rewards

    mb_rewards = mb_rewards.flatten()
    mb_actions = mb_actions.flatten()
    mb_values = mb_values.flatten()
    return mb_states, mb_rewards, mb_actions, mb_values
| 166 | |
| 167 | |
| 168 | def learn(network, env, seed, new_session=True, nsteps=5, nstack=4, total_timesteps=int(80e6), |