Update policy using the currently gathered rollout buffer.
(self)
| 182 | self.clip_range_vf = FloatSchedule(self.clip_range_vf) |
| 183 | |
| 184 | def train(self) -> None: |
| 185 | """ |
| 186 | Update policy using the currently gathered rollout buffer. |
| 187 | """ |
| 188 | # Switch to train mode (this affects batch norm / dropout) |
| 189 | self.policy.set_training_mode(True) |
| 190 | # Update optimizer learning rate |
| 191 | self._update_learning_rate(self.policy.optimizer) |
| 192 | # Compute current clip range |
| 193 | clip_range = self.clip_range(self._current_progress_remaining) # type: ignore[operator] |
| 194 | # Optional: clip range for the value function |
| 195 | if self.clip_range_vf is not None: |
| 196 | clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator] |
| 197 | |
| 198 | entropy_losses = [] |
| 199 | pg_losses, value_losses = [], [] |
| 200 | clip_fractions = [] |
| 201 | |
| 202 | continue_training = True |
| 203 | # train for n_epochs epochs |
| 204 | for epoch in range(self.n_epochs): |
| 205 | approx_kl_divs = [] |
| 206 | # Do a complete pass on the rollout buffer |
| 207 | for rollout_data in self.rollout_buffer.get(self.batch_size): |
| 208 | actions = rollout_data.actions |
| 209 | if isinstance(self.action_space, spaces.Discrete): |
| 210 | # Convert discrete action from float to long |
| 211 | actions = rollout_data.actions.long().flatten() |
| 212 | |
| 213 | values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions) |
| 214 | values = values.flatten() |
| 215 | # Normalize advantage |
| 216 | advantages = rollout_data.advantages |
| 217 | # Normalization does not make sense if mini batchsize == 1, see GH issue #325 |
| 218 | if self.normalize_advantage and len(advantages) > 1: |
| 219 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) |
| 220 | |
| 221 | # ratio between old and new policy, should be one at the first iteration |
| 222 | ratio = th.exp(log_prob - rollout_data.old_log_prob) |
| 223 | |
| 224 | # clipped surrogate loss |
| 225 | policy_loss_1 = advantages * ratio |
| 226 | policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) |
| 227 | policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() |
| 228 | |
| 229 | # Logging |
| 230 | pg_losses.append(policy_loss.item()) |
| 231 | clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() |
| 232 | clip_fractions.append(clip_fraction) |
| 233 | |
| 234 | if self.clip_range_vf is None: |
| 235 | # No clipping |
| 236 | values_pred = values |
| 237 | else: |
| 238 | # Clip the difference between old and new value |
| 239 | # NOTE: this depends on the reward scaling |
| 240 | values_pred = rollout_data.old_values + th.clamp( |
| 241 | values - rollout_data.old_values, -clip_range_vf, clip_range_vf |
nothing calls this directly
no test coverage detected