MCPcopy
hub / github.com/DLR-RM/stable-baselines3 / train

Method train

stable_baselines3/ppo/ppo.py:184–300  ·  view source on GitHub ↗

Update policy using the currently gathered rollout buffer.

(self)

Source from the content-addressed store, hash-verified

182 self.clip_range_vf = FloatSchedule(self.clip_range_vf)
183
184 def train(self) -> None:
185 """
186 Update policy using the currently gathered rollout buffer.
187 """
188 # Switch to train mode (this affects batch norm / dropout)
189 self.policy.set_training_mode(True)
190 # Update optimizer learning rate
191 self._update_learning_rate(self.policy.optimizer)
192 # Compute current clip range
193 clip_range = self.clip_range(self._current_progress_remaining) # type: ignore[operator]
194 # Optional: clip range for the value function
195 if self.clip_range_vf is not None:
196 clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator]
197
198 entropy_losses = []
199 pg_losses, value_losses = [], []
200 clip_fractions = []
201
202 continue_training = True
203 # train for n_epochs epochs
204 for epoch in range(self.n_epochs):
205 approx_kl_divs = []
206 # Do a complete pass on the rollout buffer
207 for rollout_data in self.rollout_buffer.get(self.batch_size):
208 actions = rollout_data.actions
209 if isinstance(self.action_space, spaces.Discrete):
210 # Convert discrete action from float to long
211 actions = rollout_data.actions.long().flatten()
212
213 values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
214 values = values.flatten()
215 # Normalize advantage
216 advantages = rollout_data.advantages
217 # Normalization does not make sense if mini batchsize == 1, see GH issue #325
218 if self.normalize_advantage and len(advantages) > 1:
219 advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
220
221 # ratio between old and new policy, should be one at the first iteration
222 ratio = th.exp(log_prob - rollout_data.old_log_prob)
223
224 # clipped surrogate loss
225 policy_loss_1 = advantages * ratio
226 policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
227 policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
228
229 # Logging
230 pg_losses.append(policy_loss.item())
231 clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
232 clip_fractions.append(clip_fraction)
233
234 if self.clip_range_vf is None:
235 # No clipping
236 values_pred = values
237 else:
238 # Clip the difference between old and new value
239 # NOTE: this depends on the reward scaling
240 values_pred = rollout_data.old_values + th.clamp(
241 values - rollout_data.old_values, -clip_range_vf, clip_range_vf

Callers

nothing calls this directly

Calls 7

explained_variance — Function · 0.90
_update_learning_rate — Method · 0.80
evaluate_actions — Method · 0.80
record — Method · 0.80
set_training_mode — Method · 0.45
get — Method · 0.45
step — Method · 0.45

Tested by

no test coverage detected