Clipped surrogate policy loss. Returns (loss, clip_fraction).
(
new_logp: torch.Tensor,
old_logp: torch.Tensor,
advantages: torch.Tensor,
mask: torch.Tensor,
clip: float = 0.2,
)
| 66 | |
| 67 | |
def ppo_policy_loss(
    new_logp: torch.Tensor,
    old_logp: torch.Tensor,
    advantages: torch.Tensor,
    mask: torch.Tensor,
    clip: float = 0.2,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Clipped surrogate policy loss. Returns (loss, clip_fraction).

    With importance ratio ``r = exp(new_logp - old_logp)`` this computes
    ``-masked_mean(min(r * A, clamp(r, 1 - clip, 1 + clip) * A), mask)``.

    Args:
        new_logp: Log-probabilities under the current policy.
        old_logp: Log-probabilities under the policy that collected the data.
        advantages: Advantage estimates, combined elementwise with the ratio
            (presumably the same shape as the log-probs — TODO confirm).
        mask: Selects the positions that count. NOTE(review): assumes
            ``masked_mean`` gives zero weight wherever ``mask == 0`` — confirm.
        clip: Half-width of the allowed band around a ratio of 1.

    Returns:
        ``(loss, clip_fraction)`` — the negated masked mean of the clipped
        surrogate, and the masked mean of the indicator ``|r - 1| > clip``.
    """
    log_ratio = new_logp - old_logp
    # Neutralise excluded positions *before* exponentiating: padding often
    # carries arbitrary log-probs, exp() of a large difference overflows to
    # inf, and inf * advantage (or inf * 0) yields inf/NaN that a
    # multiplicative mask cannot remove (NaN * 0 == NaN). A log-ratio of 0
    # (ratio 1) is finite and contributes nothing once masked. Doing this
    # with where() before exp() also keeps the backward pass NaN-free there.
    log_ratio = torch.where(mask.bool(), log_ratio, torch.zeros_like(log_ratio))
    ratio = torch.exp(log_ratio)

    surr_unclipped = ratio * advantages
    surr_clipped = torch.clamp(ratio, 1.0 - clip, 1.0 + clip) * advantages
    # Pessimistic (elementwise minimum) objective, negated so it can be minimised.
    loss = -masked_mean(torch.min(surr_unclipped, surr_clipped), mask)

    # Share of counted positions whose ratio fell outside [1 - clip, 1 + clip].
    clip_fraction = masked_mean(((ratio - 1.0).abs() > clip).float(), mask)
    return loss, clip_fraction
| 82 | |
| 83 | |
| 84 | def ppo_value_loss( |