Per-token unbiased, non-negative KL estimator (Schulman's k3) for KL(policy||ref).
(new_logp: torch.Tensor, ref_logp: torch.Tensor)
| 29 | |
| 30 | |
| 31 | def k3_kl(new_logp: torch.Tensor, ref_logp: torch.Tensor) -> torch.Tensor: |
| 32 | """Per-token unbiased, non-negative KL estimator (Schulman's k3) for KL(policy||ref).""" |
| 33 | diff = ref_logp - new_logp |
| 34 | return torch.exp(diff) - diff - 1.0 |
| 35 | |
| 36 | |
| 37 | def grpo_loss( |
no outgoing calls