Update policy using the currently gathered rollout buffer (one gradient step over whole data).
(self)
| 130 | self._setup_model() |
| 131 | |
| 132 | def train(self) -> None: |
| 133 | """ |
| 134 | Update policy using the currently gathered |
| 135 | rollout buffer (one gradient step over whole data). |
| 136 | """ |
| 137 | # Switch to train mode (this affects batch norm / dropout) |
| 138 | self.policy.set_training_mode(True) |
| 139 | |
| 140 | # Update optimizer learning rate |
| 141 | self._update_learning_rate(self.policy.optimizer) |
| 142 | |
| 143 | # This will only loop once (get all data in one go) |
| 144 | for rollout_data in self.rollout_buffer.get(batch_size=None): |
| 145 | actions = rollout_data.actions |
| 146 | if isinstance(self.action_space, spaces.Discrete): |
| 147 | # Convert discrete action from float to long |
| 148 | actions = actions.long().flatten() |
| 149 | |
| 150 | values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions) |
| 151 | values = values.flatten() |
| 152 | |
| 153 | # Normalize advantage (not present in the original implementation) |
| 154 | advantages = rollout_data.advantages |
| 155 | if self.normalize_advantage: |
| 156 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) |
| 157 | |
| 158 | # Policy gradient loss |
| 159 | policy_loss = -(advantages * log_prob).mean() |
| 160 | |
| 161 | # Value loss using the TD(gae_lambda) target |
| 162 | value_loss = F.mse_loss(rollout_data.returns, values) |
| 163 | |
| 164 | # Entropy loss favors exploration |
| 165 | if entropy is None: |
| 166 | # Approximate entropy when no analytical form |
| 167 | entropy_loss = -th.mean(-log_prob) |
| 168 | else: |
| 169 | entropy_loss = -th.mean(entropy) |
| 170 | |
| 171 | loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss |
| 172 | |
| 173 | # Optimization step |
| 174 | self.policy.optimizer.zero_grad() |
| 175 | loss.backward() |
| 176 | |
| 177 | # Clip grad norm |
| 178 | th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) |
| 179 | self.policy.optimizer.step() |
| 180 | |
| 181 | explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) |
| 182 | |
| 183 | self._n_updates += 1 |
| 184 | self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") |
| 185 | self.logger.record("train/explained_variance", explained_var) |
| 186 | self.logger.record("train/entropy_loss", entropy_loss.item()) |
| 187 | self.logger.record("train/policy_loss", policy_loss.item()) |
| 188 | self.logger.record("train/value_loss", value_loss.item()) |
| 189 | if hasattr(self.policy, "log_std"): |
nothing calls this directly
no test coverage detected