Repository: github.com/FareedKhan-dev/train-llm-from-scratch

Function sequence_logprobs

src/post_training/rollout.py:268–284

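Computes the sequence-level sum of log-probabilities over response tokens (used by DPO/KTO/ORPO). Returns ``(sum_logprob, n_tokens)``, each of shape (B,). The per-token mean is ``sum_logprob / n_tokens.clamp(min=1)``; the clamp avoids dividing by zero for rows that contain no response tokens.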
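To make the reduction concrete, here is a minimal sketch of the masking arithmetic on hand-written tensors. It assumes per-token log-probs `lp` of shape (B, T) and a 0/1 response mask of the same shape; the values are made up for illustration. The second row has no response tokens, which is exactly the case the `clamp(min=1)` guards against.

```python
import torch

# Made-up per-token log-probs for B=2 sequences of T=4 positions.
lp = torch.log(torch.tensor([
    [0.5, 0.5, 0.25, 1.0],
    [1.0, 0.5, 0.5, 0.5],
]))
# Made-up response mask: 1 = response token, 0 = prompt or padding.
mask = torch.tensor([
    [0, 1, 1, 0],
    [0, 0, 0, 0],  # empty response
])

m = mask.to(lp.dtype)
sum_logprob = (lp * m).sum(dim=-1)  # shape (B,): log(0.5) + log(0.25) for row 0
n_tokens = m.sum(dim=-1)            # shape (B,): number of response tokens per row
mean_logprob = sum_logprob / n_tokens.clamp(min=1)

# Without the clamp, row 1 becomes 0 / 0 = NaN.
naive_mean = sum_logprob / n_tokens
```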

(
    model,
    sequences: torch.Tensor,
    response_mask: torch.Tensor,
    *,
    temperature: float = 1.0,
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]


```python
import torch

# Excerpt from src/post_training/rollout.py; compute_logprobs is defined or
# imported elsewhere in the same module and is not shown here.
def sequence_logprobs(
    model,
    sequences: torch.Tensor,
    response_mask: torch.Tensor,
    *,
    temperature: float = 1.0,
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sequence-level summed log-prob over response tokens (used by DPO/KTO/ORPO).

    Returns ``(sum_logprob, n_tokens)`` each shape (B,). The per-token mean is
    ``sum_logprob / n_tokens.clamp(min=1)``.
    """
    # Per-token log-probs plus the mask of response positions to score.
    lp, mask = compute_logprobs(
        model, sequences, response_mask, temperature=temperature, requires_grad=requires_grad
    )
    # Cast the mask to lp's dtype so non-response positions contribute zero.
    m = mask.to(lp.dtype)
    # Reduce over the token axis: summed log-prob and response length, each (B,).
    return (lp * m).sum(dim=-1), m.sum(dim=-1)
```
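The listing depends on `compute_logprobs`, whose body is not shown on this page. The sketch below runs the same reduction end to end by pairing it with a hypothetical stand-in for `compute_logprobs` and a hypothetical `ToyLM`; neither is the repository's code. The stand-in assumes the model returns raw logits of shape (B, T, V), that the logits at position t score token t + 1, that `temperature` divides the logits, and that `requires_grad=False` means the forward pass runs under `torch.no_grad()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_logprobs(model, sequences, response_mask, *, temperature=1.0, requires_grad=True):
    """HYPOTHETICAL stand-in: the repository's real compute_logprobs is not shown.

    Assumes logits at position t predict token t + 1, so outputs have length T - 1.
    """
    ctx = torch.enable_grad() if requires_grad else torch.no_grad()
    with ctx:
        logits = model(sequences)[:, :-1, :] / temperature       # (B, T-1, V)
        logp = F.log_softmax(logits, dim=-1)
        targets = sequences[:, 1:]                                # (B, T-1)
        lp = logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)   # (B, T-1)
    mask = response_mask[:, 1:].bool()                            # align with targets
    return lp, mask


def sequence_logprobs(model, sequences, response_mask, *, temperature=1.0, requires_grad=True):
    # Same reduction as the source listing above.
    lp, mask = compute_logprobs(
        model, sequences, response_mask, temperature=temperature, requires_grad=requires_grad
    )
    m = mask.to(lp.dtype)
    return (lp * m).sum(dim=-1), m.sum(dim=-1)


class ToyLM(nn.Module):
    """HYPOTHETICAL tiny language model: embedding followed by a linear head."""

    def __init__(self, vocab: int = 11, dim: int = 8):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return self.head(self.emb(ids))  # (B, T, V) logits


torch.manual_seed(0)
model = ToyLM()
sequences = torch.randint(0, 11, (2, 6))
# First 3 positions are prompt; the rest is response (row 1 ends with padding).
response_mask = torch.tensor([[0, 0, 0, 1, 1, 1],
                              [0, 0, 0, 1, 1, 0]])

# Policy-style pass: gradients can flow back into the model.
pol_sum, pol_n = sequence_logprobs(model, sequences, response_mask)
# Reference-style pass: no autograd graph is built.
ref_sum, ref_n = sequence_logprobs(model, sequences, response_mask, requires_grad=False)
```

Under these assumptions the same function can score both a trainable policy and a frozen reference model; only the `requires_grad` flag changes, and the returned values are identical.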

Callers (1): _logps · Function · 0.90
Calls (1): compute_logprobs · Function · 0.85
Tested by: no test coverage detected
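The caller `_logps` is not shown here, so the following is only a sketch of how summed log-probs typically feed the standard DPO loss. It assumes the loss receives `sum_logprob` for the chosen and rejected responses from both the policy and a frozen reference model; `dpo_loss` and `beta` are illustrative names, not this repository's API.

```python
import torch
import torch.nn.functional as F


def dpo_loss(pol_chosen, pol_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Standard DPO objective on summed sequence log-probs, each of shape (B,).

    HYPOTHETICAL helper: inputs are assumed to be the sum_logprob outputs of
    sequence_logprobs under the policy and the reference model.
    """
    chosen_ratio = pol_chosen - ref_chosen        # log(pi / pi_ref) for chosen
    rejected_ratio = pol_rejected - ref_rejected  # log(pi / pi_ref) for rejected
    logits = beta * (chosen_ratio - rejected_ratio)
    return -F.logsigmoid(logits).mean()


zeros = torch.zeros(3)
# Policy identical to the reference: the loss is log(2).
loss_tied = dpo_loss(zeros, zeros, zeros, zeros)
# Policy raised the chosen log-prob by 5 nats relative to the reference.
loss_better = dpo_loss(torch.full((3,), -10.0), torch.full((3,), -20.0),
                       torch.full((3,), -15.0), torch.full((3,), -20.0))
```

DPO compares whole-sequence log-likelihoods, so the sum is the natural input. Length-normalized objectives, such as the odds-ratio term in ORPO as usually formulated, instead use the per-token mean, which is the ``sum_logprob / n_tokens.clamp(min=1)`` recipe from the docstring.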