Normalize advantages to zero mean / unit std over masked (response) positions.
(advantages: torch.Tensor, mask: torch.Tensor)
| 58 | |
| 59 | |
def whiten(advantages: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Standardize advantages with statistics taken over the masked positions.

    Both the mean and the variance are obtained from ``masked_mean`` using the
    float-cast mask, so the result is intended to have zero mean and unit
    standard deviation over the masked (response) positions.
    NOTE(review): this relies on ``masked_mean`` reducing only over positions
    selected by the mask and returning something broadcastable against
    ``advantages`` -- confirm against its definition.

    Args:
        advantages: Advantage values to normalize.
        mask: Mask selecting the positions that contribute to the statistics;
            it is cast to float before use.

    Returns:
        ``(advantages - mean) / (std + 1e-8)`` multiplied elementwise by the
        float mask.
    """
    weights = mask.float()
    center = masked_mean(advantages, weights)
    deviation = advantages - center
    spread = masked_mean(deviation ** 2, weights).sqrt()
    # The small constant keeps the denominator non-zero when the spread is zero.
    scaled = deviation / (spread + 1e-8)
    return scaled * weights
| 66 | |
| 67 | |
| 68 | def ppo_policy_loss( |