def grpo_loss(
    new_logp: torch.Tensor,
    old_logp: torch.Tensor,
    ref_logp: torch.Tensor,
    advantages: torch.Tensor,
    resp_mask: torch.Tensor,
    clip: float = 0.2,
    kl_coef: float = 0.04,
) -> tuple[torch.Tensor, dict]:
    """
    Token-level clipped surrogate + KL penalty.

    Args:
        new_logp/old_logp/ref_logp: (B, L) per-token log-probs (policy / sampling / ref).
        advantages: (B,) one scalar per completion, broadcast over its tokens.
        resp_mask: (B, L) bool over response tokens.

    Returns:
        (loss, stats): scalar loss to minimize, and a dict with mean KL ("kl")
        and clip fraction ("clipfrac") for logging.
    """
    # (B,) -> (B, 1) so each completion's advantage broadcasts over its L tokens.
    adv = advantages[:, None]
    # Per-token importance ratio between the current policy and the sampling policy.
    ratio = torch.exp(new_logp - old_logp)
    surr1 = ratio * adv
    surr2 = torch.clamp(ratio, 1.0 - clip, 1.0 + clip) * adv
    # Pessimistic bound: elementwise minimum of the unclipped and clipped terms.
    surrogate = torch.min(surr1, surr2)
    # Per-token KL estimate of the policy against the reference model.
    kl = k3_kl(new_logp, ref_logp)

    per_token = surrogate - kl_coef * kl
    # Negate: the objective is maximized, the optimizer minimizes the loss.
    loss = -masked_mean(per_token, resp_mask)
    stats = {
        "kl": masked_mean(kl, resp_mask).item(),
        # Fraction of response tokens whose ratio lies outside [1 - clip, 1 + clip].
        "clipfrac": masked_mean(((ratio - 1.0).abs() > clip).float(), resp_mask).item(),
    }
    return loss, stats
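
`grpo_loss` relies on two helpers, `k3_kl` and `masked_mean`, whose definitions are not shown in this excerpt. The sketch below is one plausible way they could be written, not the file's actual code: it assumes `k3_kl` is the common "k3" estimator `exp(d) - d - 1` with `d = ref_logp - logp`, and that `masked_mean` averages over the positions where the mask is true. The toy tensors are made up for illustration.

```python
import torch


def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean of x over positions where mask is True (assumed helper)."""
    mask = mask.bool()
    # torch.where keeps NaN/inf at masked-out positions from leaking into the sum.
    total = torch.where(mask, x, torch.zeros_like(x)).sum()
    # clamp avoids division by zero when the mask selects nothing.
    return total / mask.sum().clamp(min=1)


def k3_kl(logp: torch.Tensor, ref_logp: torch.Tensor) -> torch.Tensor:
    """Per-token k3 estimate of KL(policy || ref) (assumed helper).

    With d = ref_logp - logp, the estimate is exp(d) - d - 1, which is
    non-negative and zero when the two log-probs agree.
    """
    d = ref_logp - logp
    return torch.exp(d) - d - 1.0


# Toy batch: B=2 completions, L=3 tokens; the last token of row 1 is padding.
new_logp = torch.tensor([[-1.0, -2.0, -0.5], [-1.5, -0.7, -3.0]])
ref_logp = torch.tensor([[-1.0, -1.8, -0.6], [-1.4, -0.7, -9.0]])
resp_mask = torch.tensor([[True, True, True], [True, True, False]])

kl = k3_kl(new_logp, ref_logp)        # shape (2, 3), elementwise >= 0
mean_kl = masked_mean(kl, resp_mask)  # scalar; the padded position is ignored
```

With helpers of this shape in scope, `grpo_loss` can be called on `(B, L)` log-prob tensors and a `(B,)` advantage vector. Passing `old_logp = new_logp.detach()` makes every ratio equal to 1, so `clipfrac` should then be 0.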