Repository: github.com/DLR-RM/stable-baselines3

Method train

stable_baselines3/sac/sac.py:202–302 (the excerpt below covers lines 200–259)
(self, gradient_steps: int, batch_size: int = 64)


```python
        self.critic_target = self.policy.critic_target

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizers learning rate
        optimizers = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers += [self.ent_coef_optimizer]

        # Update learning rate according to lr schedule
        self._update_learning_rate(optimizers)

        ent_coef_losses, ent_coefs = [], []
        actor_losses, critic_losses = [], []

        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
            # For n-step replay, discount factor is gamma**n_steps (when no early termination)
            discounts = replay_data.discounts if replay_data.discounts is not None else self.gamma

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            # Action by the current actor for the sampled state
            actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
            log_prob = log_prob.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = th.exp(self.log_ent_coef.detach())
                assert isinstance(self.target_entropy, float)
                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            ent_coefs.append(ent_coef.item())

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None and self.ent_coef_optimizer is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            with th.no_grad():
                # Select action according to policy
                next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
                # Compute the next Q values: min over all critics targets
                next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                # add entropy term
                next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
                # td error + entropy term
                # ... (excerpt ends here; the method continues through line 302)
```
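The entropy-coefficient branch minimizes `-(log_ent_coef * (log_prob + target_entropy).detach()).mean()`. Because the bracketed term is detached, the loss is linear in `log_ent_coef`, and its gradient is simply `-mean(log_prob + target_entropy)`. The sketch below reproduces one such update on plain Python floats. It assumes a plain SGD step in place of the PyTorch optimizer SB3 actually uses, and `ent_coef_step` and its arguments are hypothetical names for illustration.

```python
from __future__ import annotations

import math


def ent_coef_step(
    log_ent_coef: float,
    log_probs: list[float],
    target_entropy: float,
    lr: float,
) -> tuple[float, float]:
    """One plain-SGD step on log(alpha); returns (new_log_ent_coef, loss)."""
    # (log_prob + target_entropy) is treated as a constant, mirroring .detach(),
    # so the loss is linear in log_ent_coef.
    mean_term = sum(lp + target_entropy for lp in log_probs) / len(log_probs)
    loss = -log_ent_coef * mean_term
    grad = -mean_term  # d(loss) / d(log_ent_coef)
    return log_ent_coef - lr * grad, loss


# Hypothetical 2-D action space; target entropy set to -action_dim (a common heuristic)
target_entropy = -2.0
# High log-probs mean low entropy (about -3.0 < -2.0), so alpha should grow
new_log_alpha, loss = ent_coef_step(0.0, [3.0, 3.0], target_entropy, lr=0.1)
print(new_log_alpha, math.exp(new_log_alpha))
```

When the batch entropy estimate `-mean(log_prob)` falls below `target_entropy`, the step raises `log_ent_coef`, and therefore also `ent_coef = exp(log_ent_coef)`, which strengthens the entropy bonus. When the estimate is above the target, the coefficient shrinks. Optimizing the logarithm keeps the coefficient positive.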

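The `th.no_grad()` block builds the critic target from the minimum over the target critics (clipped double-Q) and then subtracts the entropy term. The listing stops at the comment `# td error + entropy term`, so the final backup line is not shown. The sketch below assumes the standard SAC form `reward + (1 - done) * discount * next_q`. In that form, `discount` is `gamma`, or `gamma ** n_steps` for n-step replay, as the comment in the listing notes. `soft_td_target` is a hypothetical helper that works on one transition with plain floats.

```python
from __future__ import annotations


def soft_td_target(
    reward: float,
    done: float,
    discount: float,
    next_q_per_critic: list[float],
    ent_coef: float,
    next_log_prob: float,
) -> float:
    """Soft Bellman target for a single transition (illustrative sketch)."""
    # Clipped double-Q: take the most pessimistic target-critic estimate
    min_next_q = min(next_q_per_critic)
    # Soft value of the next state: subtract alpha * log pi(a'|s')
    soft_next_v = min_next_q - ent_coef * next_log_prob
    # Assumed standard SAC backup; done == 1.0 cuts off bootstrapping
    return reward + (1.0 - done) * discount * soft_next_v


gamma = 0.99
# One-step transition: discount is gamma
print(soft_td_target(1.0, 0.0, gamma, [5.0, 4.0], 0.2, -1.0))
# Three-step return without early termination: discount is gamma ** 3
print(soft_td_target(1.0, 0.0, gamma ** 3, [5.0, 4.0], 0.2, -1.0))
```

In the one-step case, the minimum critic value is 4.0. Subtracting `0.2 * -1.0` raises it to 4.2, and the target is `1 + 0.99 * 4.2`, about 5.158. Taking the minimum rather than the mean counters the overestimation bias of bootstrapped Q-learning.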
Callers 1

Calls 8

polyak_update · Function · 0.90
_update_learning_rate · Method · 0.80
action_log_prob · Method · 0.80
record · Method · 0.80
set_training_mode · Method · 0.45
sample · Method · 0.45
reset_noise · Method · 0.45
step · Method · 0.45
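`polyak_update` appears among the calls, but it is invoked in the part of `train()` not shown above. A soft (Polyak) target update moves each target parameter a fraction `tau` of the way toward the online parameter: `target <- (1 - tau) * target + tau * online`. The sketch below applies that rule to plain lists of floats. `polyak_update_lists` is a hypothetical stand-in, not SB3's function, which works in place on PyTorch parameter tensors.

```python
from __future__ import annotations


def polyak_update_lists(params: list[float], target_params: list[float], tau: float) -> None:
    """In-place soft update: target <- (1 - tau) * target + tau * online."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must lie in [0, 1]")
    for i, (online, target) in enumerate(zip(params, target_params)):
        target_params[i] = (1.0 - tau) * target + tau * online


online_params = [1.0, 2.0]
target_params = [0.0, 0.0]
polyak_update_lists(online_params, target_params, tau=0.5)
print(target_params)  # [0.5, 1.0]
```

With a small `tau` (for example 0.005), the target critic trails the online critic slowly, which keeps the bootstrapped targets stable. Setting `tau = 1.0` reduces the update to a hard copy.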

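`_update_learning_rate(optimizers)` applies the learning-rate schedule to every optimizer collected at the top of `train()`. In SB3, a schedule is a callable of the remaining training progress, which goes from 1.0 at the start of training to 0.0 at the end. The sketch below shows a linear schedule and a hypothetical `apply_schedule` helper. The helper writes the scheduled value into `param_groups`-style dicts, which is where PyTorch optimizers keep their `"lr"` entry. It illustrates the idea rather than reproducing SB3's method.

```python
from __future__ import annotations

from typing import Callable


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Learning rate that decays linearly from initial_value to 0."""
    def schedule(progress_remaining: float) -> float:
        # progress_remaining: 1.0 at the start of training, 0.0 at the end
        return progress_remaining * initial_value
    return schedule


def apply_schedule(
    param_groups: list[dict],
    schedule: Callable[[float], float],
    progress_remaining: float,
) -> None:
    """Write the scheduled learning rate into every parameter group."""
    new_lr = schedule(progress_remaining)
    for group in param_groups:
        group["lr"] = new_lr


# Hypothetical stand-ins for the actor, critic and ent_coef optimizers' param_groups
groups = [{"lr": 3e-4}, {"lr": 3e-4}, {"lr": 3e-4}]
apply_schedule(groups, linear_schedule(3e-4), progress_remaining=0.5)
print([g["lr"] for g in groups])
```

A callable like `linear_schedule(3e-4)` can be passed as `learning_rate` when constructing the model. Because `ent_coef_optimizer` is appended to `optimizers` only when it exists, the entropy coefficient follows the same schedule as the actor and critic whenever it is being learned.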
Tested by 1