MCPcopy
hub / github.com/DLR-RM/stable-baselines3 / train

Method train

stable_baselines3/a2c/a2c.py:132–190  ·  view source on GitHub ↗

Update the policy using the currently gathered rollout buffer (a single gradient step over the whole of the collected data).

(self)

Source from the content-addressed store, hash-verified

130 self._setup_model()
131
132 def train(self) -> None:
133 """
134 Update policy using the currently gathered
135 rollout buffer (one gradient step over whole data).
136 """
137 # Switch to train mode (this affects batch norm / dropout)
138 self.policy.set_training_mode(True)
139
140 # Update optimizer learning rate
141 self._update_learning_rate(self.policy.optimizer)
142
143 # This will only loop once (get all data in one go)
144 for rollout_data in self.rollout_buffer.get(batch_size=None):
145 actions = rollout_data.actions
146 if isinstance(self.action_space, spaces.Discrete):
147 # Convert discrete action from float to long
148 actions = actions.long().flatten()
149
150 values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
151 values = values.flatten()
152
153 # Normalize advantage (not present in the original implementation)
154 advantages = rollout_data.advantages
155 if self.normalize_advantage:
156 advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
157
158 # Policy gradient loss
159 policy_loss = -(advantages * log_prob).mean()
160
161 # Value loss using the TD(gae_lambda) target
162 value_loss = F.mse_loss(rollout_data.returns, values)
163
164 # Entropy loss favor exploration
165 if entropy is None:
166 # Approximate entropy when no analytical form
167 entropy_loss = -th.mean(-log_prob)
168 else:
169 entropy_loss = -th.mean(entropy)
170
171 loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
172
173 # Optimization step
174 self.policy.optimizer.zero_grad()
175 loss.backward()
176
177 # Clip grad norm
178 th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
179 self.policy.optimizer.step()
180
181 explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
182
183 self._n_updates += 1
184 self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
185 self.logger.record("train/explained_variance", explained_var)
186 self.logger.record("train/entropy_loss", entropy_loss.item())
187 self.logger.record("train/policy_loss", policy_loss.item())
188 self.logger.record("train/value_loss", value_loss.item())
189 if hasattr(self.policy, "log_std"):

Callers

nothing calls this directly

Calls 7

explained_varianceFunction · 0.90
_update_learning_rateMethod · 0.80
evaluate_actionsMethod · 0.80
recordMethod · 0.80
set_training_modeMethod · 0.45
getMethod · 0.45
stepMethod · 0.45

Tested by

no test coverage detected