Function grpo_loss

Repository: github.com/FareedKhan-dev/train-llm-from-scratch

src/post_training/grpo.py:37–70

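Token-level clipped surrogate + KL penalty.

Args:
  new_logp / old_logp / ref_logp: (B, L) per-token log-probs under the current policy, the sampling policy, and the reference model, respectively. B is the number of completions and L the sequence length.
  advantages: (B,) one scalar per completion, broadcast over that completion's tokens.
  resp_mask: (B, L) bool mask that selects response tokens.
  clip: clipping range for the probability ratio, which is clamped to [1 - clip, 1 + clip] (default 0.2).
  kl_coef: weight of the KL penalty against the reference model (default 0.04).

Returns:
  (loss, stats): the scalar loss to minimize, and a dict with the mean KL ("kl") and the clip fraction ("clipfrac") for logging.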
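Written out per token, the objective the code computes is given below. Here m is resp_mask, A_b = advantages[b], ε = clip, β = kl_coef, and \hat{k}_{b,t} is the per-token value returned by k3_kl. The denominator assumes that masked_mean averages over all response tokens in the batch. Its definition is not shown on this page, so treat that normalization as an assumption.

```latex
\begin{aligned}
r_{b,t} &= \exp\!\big(\texttt{new\_logp}_{b,t} - \texttt{old\_logp}_{b,t}\big) \\
\mathcal{L} &= -\frac{\sum_{b,t} m_{b,t}\,\Big[\min\!\big(r_{b,t} A_b,\ \operatorname{clip}(r_{b,t},\, 1-\epsilon,\, 1+\epsilon)\,A_b\big) - \beta\,\hat{k}_{b,t}\Big]}{\sum_{b,t} m_{b,t}}
\end{aligned}
```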

(
    new_logp: torch.Tensor,
    old_logp: torch.Tensor,
    ref_logp: torch.Tensor,
    advantages: torch.Tensor,
    resp_mask: torch.Tensor,
    clip: float = 0.2,
    kl_coef: float = 0.04,
) -> tuple[torch.Tensor, dict]


```python
import torch

# src/post_training/grpo.py, lines 37–70. masked_mean and k3_kl are defined
# outside this excerpt.
def grpo_loss(
    new_logp: torch.Tensor,
    old_logp: torch.Tensor,
    ref_logp: torch.Tensor,
    advantages: torch.Tensor,
    resp_mask: torch.Tensor,
    clip: float = 0.2,
    kl_coef: float = 0.04,
) -> tuple[torch.Tensor, dict]:
    """
    Token-level clipped surrogate + KL penalty.

    Args:
        new_logp/old_logp/ref_logp: (B, L) per-token log-probs (policy / sampling / ref).
        advantages: (B,) one scalar per completion, broadcast over its tokens.
        resp_mask: (B, L) bool over response tokens.

    Returns:
        (loss, stats) with mean KL and clip fraction for logging.
    """
    adv = advantages[:, None]  # (B,) -> (B, 1) so it broadcasts over tokens
    ratio = torch.exp(new_logp - old_logp)  # importance ratio pi_new / pi_old per token
    surr1 = ratio * adv
    surr2 = torch.clamp(ratio, 1.0 - clip, 1.0 + clip) * adv
    surrogate = torch.min(surr1, surr2)  # pessimistic (clipped) surrogate
    kl = k3_kl(new_logp, ref_logp)  # per-token KL estimate against the reference model

    per_token = surrogate - kl_coef * kl
    loss = -masked_mean(per_token, resp_mask)  # negated: the optimizer minimizes
    stats = {
        "kl": masked_mean(kl, resp_mask).item(),
        # fraction of response tokens whose ratio lies outside [1 - clip, 1 + clip]
        "clipfrac": masked_mean(((ratio - 1.0).abs() > clip).float(), resp_mask).item(),
    }
    return loss, stats
```
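You can check the arithmetic without PyTorch using a plain-Python mirror of the function above, which works on nested lists. This is a sketch under two assumptions that the page does not confirm. First, masked_mean is taken to be a single mean over every masked-in token in the batch. Second, k3_kl is taken to be the standard k3 estimator. The name grpo_loss_reference is hypothetical and not part of the repository.

```python
import math


def grpo_loss_reference(new_logp, old_logp, ref_logp, advantages, resp_mask,
                        clip=0.2, kl_coef=0.04):
    """Hypothetical plain-Python mirror of grpo_loss on nested (B, L) lists.

    Assumptions (not confirmed by the source): masked_mean is one mean over
    every masked-in token in the batch, and k3_kl(new, ref) is the k3
    estimator exp(ref - new) - (ref - new) - 1.
    """
    total = kl_sum = 0.0
    clipped = count = 0
    for b, adv in enumerate(advantages):
        for t, keep in enumerate(resp_mask[b]):
            if not keep:
                continue  # prompt and padding positions are ignored
            ratio = math.exp(new_logp[b][t] - old_logp[b][t])
            surr1 = ratio * adv
            surr2 = min(max(ratio, 1.0 - clip), 1.0 + clip) * adv  # clamped ratio
            surrogate = min(surr1, surr2)  # pessimistic bound, like torch.min
            log_r = ref_logp[b][t] - new_logp[b][t]
            kl = math.exp(log_r) - log_r - 1.0  # assumed k3_kl
            total += surrogate - kl_coef * kl
            kl_sum += kl
            clipped += int(abs(ratio - 1.0) > clip)
            count += 1
    count = max(count, 1)  # avoid dividing by zero on an all-False mask
    loss = -total / count
    return loss, {"kl": kl_sum / count, "clipfrac": clipped / count}


# Two completions: advantage +1 (two response tokens) and -1 (one response token).
lh, l75 = math.log(0.5), math.log(0.75)
new = [[lh, l75, 0.0], [l75, 0.0, 0.0]]
old = [[lh, lh, 0.0], [lh, 0.0, 0.0]]
mask = [[True, True, False], [True, False, False]]
loss, stats = grpo_loss_reference(new, old, new, [1.0, -1.0], mask)
print(round(loss, 4), round(stats["clipfrac"], 4))  # -0.2333 0.6667
```

The example shows how clipping behaves differently for the two advantage signs:
- **Positive advantage (+1):** the ratio 1.5 is clamped to 1.2, so pushing the ratio higher brings no further gain.
- **Negative advantage (-1):** the min keeps the unclipped -1.5, so the penalty is not capped.

Clipfrac counts both tokens because it only checks whether |ratio - 1| > clip, not whether the clamp was the active branch.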

Callers (4):
  main (function)
  test_grpo_loss_and_kl (function)
  verify_grpo_optimizes (function)
  grpo_live.py (file)

Calls (2):
  masked_mean (function)
  k3_kl (function)
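Neither masked_mean nor k3_kl is shown on this page. Below is a plain-Python sketch of what they plausibly compute. It assumes k3_kl is the common "k3" KL estimator, exp(ref - new) - (ref - new) - 1, which is the form used in the GRPO objective. It also assumes masked_mean averages over all masked-in positions of a (B, L) batch. Both bodies are hypothetical. The repository's versions operate on tensors and may normalize differently, for example per sequence.

```python
import math


def k3_kl(new_logp: float, ref_logp: float) -> float:
    """Hypothetical per-token k3 estimate of KL(policy || reference).

    With log_r = ref_logp - new_logp, k3 = exp(log_r) - log_r - 1. Since
    exp(x) >= 1 + x, the value is never negative, and it is 0 when both
    log-probs agree.
    """
    log_r = ref_logp - new_logp
    return math.exp(log_r) - log_r - 1.0


def masked_mean(values, mask) -> float:
    """Hypothetical masked mean: average of values[b][t] where mask[b][t] is True."""
    picked = [v for row_v, row_m in zip(values, mask)
              for v, m in zip(row_v, row_m) if m]
    return sum(picked) / max(len(picked), 1)  # guard against an empty mask


print(k3_kl(math.log(0.5), math.log(0.5)))  # 0.0
print(round(k3_kl(math.log(0.5), math.log(0.25)), 4))  # 0.1931 (= 0.5 + ln 2 - 1)
print(masked_mean([[1.0, 2.0, 9.0], [3.0, 9.0, 9.0]],
                  [[True, True, False], [True, False, False]]))  # 2.0
```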

Tested by (1):
  test_grpo_loss_and_kl (function)