| 200 | self.critic_target = self.policy.critic_target |
| 201 | |
| 202 | def train(self, gradient_steps: int, batch_size: int = 64) -> None: |
| 203 | # Switch to train mode (this affects batch norm / dropout) |
| 204 | self.policy.set_training_mode(True) |
| 205 | # Update optimizers learning rate |
| 206 | optimizers = [self.actor.optimizer, self.critic.optimizer] |
| 207 | if self.ent_coef_optimizer is not None: |
| 208 | optimizers += [self.ent_coef_optimizer] |
| 209 | |
| 210 | # Update learning rate according to lr schedule |
| 211 | self._update_learning_rate(optimizers) |
| 212 | |
| 213 | ent_coef_losses, ent_coefs = [], [] |
| 214 | actor_losses, critic_losses = [], [] |
| 215 | |
| 216 | for gradient_step in range(gradient_steps): |
| 217 | # Sample replay buffer |
| 218 | replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr] |
| 219 | # For n-step replay, discount factor is gamma**n_steps (when no early termination) |
| 220 | discounts = replay_data.discounts if replay_data.discounts is not None else self.gamma |
| 221 | |
| 222 | # We need to sample because `log_std` may have changed between two gradient steps |
| 223 | if self.use_sde: |
| 224 | self.actor.reset_noise() |
| 225 | |
| 226 | # Action by the current actor for the sampled state |
| 227 | actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations) |
| 228 | log_prob = log_prob.reshape(-1, 1) |
| 229 | |
| 230 | ent_coef_loss = None |
| 231 | if self.ent_coef_optimizer is not None and self.log_ent_coef is not None: |
| 232 | # Important: detach the variable from the graph |
| 233 | # so we don't change it with other losses |
| 234 | # see https://github.com/rail-berkeley/softlearning/issues/60 |
| 235 | ent_coef = th.exp(self.log_ent_coef.detach()) |
| 236 | assert isinstance(self.target_entropy, float) |
| 237 | ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean() |
| 238 | ent_coef_losses.append(ent_coef_loss.item()) |
| 239 | else: |
| 240 | ent_coef = self.ent_coef_tensor |
| 241 | |
| 242 | ent_coefs.append(ent_coef.item()) |
| 243 | |
| 244 | # Optimize entropy coefficient, also called |
| 245 | # entropy temperature or alpha in the paper |
| 246 | if ent_coef_loss is not None and self.ent_coef_optimizer is not None: |
| 247 | self.ent_coef_optimizer.zero_grad() |
| 248 | ent_coef_loss.backward() |
| 249 | self.ent_coef_optimizer.step() |
| 250 | |
| 251 | with th.no_grad(): |
| 252 | # Select action according to policy |
| 253 | next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations) |
| 254 | # Compute the next Q values: min over all critics targets |
| 255 | next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1) |
| 256 | next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True) |
| 257 | # add entropy term |
| 258 | next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1) |
| 259 | # td error + entropy term |