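``sequence_logprobs``: sequence-level summed log-probability over the response tokens of each sequence (used by the DPO, KTO, and ORPO losses). Returns ``(sum_logprob, n_tokens)``, each of shape ``(B,)``. These are the per-sequence sum of token log-probs at response positions and the number of such positions. The per-token mean is ``sum_logprob / n_tokens.clamp(min=1)``, where the clamp prevents division by zero for sequences with no response tokens.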
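Below is a framework-free sketch of the same reduction. It uses plain Python lists in place of ``(B, T)`` tensors so that it runs without PyTorch. It mirrors the cast-then-sum steps of the function body (``mask.to(lp.dtype)``, then ``.sum(dim=-1)``). It also shows why ``clamp(min=1)`` matters: a row with no response tokens has ``n_tokens == 0``, and the clamp turns its mean into ``0 / 1`` instead of a division by zero. ``masked_sum_and_count`` and ``per_token_mean`` are hypothetical helper names used only for illustration.

```python
def masked_sum_and_count(token_logprobs, mask):
    """Plain-Python mirror of the masked reduction (hypothetical helper).

    token_logprobs and mask are lists of equal-length rows; mask entries are 0/1.
    Returns (sum_logprob, n_tokens), one value per row.
    """
    sums, counts = [], []
    for lp_row, m_row in zip(token_logprobs, mask):
        m = [float(x) for x in m_row]                          # like mask.to(lp.dtype)
        sums.append(sum(v * w for v, w in zip(lp_row, m)))    # like (lp * m).sum(dim=-1)
        counts.append(sum(m))                                  # like m.sum(dim=-1)
    return sums, counts


def per_token_mean(sums, counts):
    # max(c, 1.0) plays the role of n_tokens.clamp(min=1): no division by zero.
    return [s / max(c, 1.0) for s, c in zip(sums, counts)]


# Row 0: first (prompt) token masked out, two response tokens kept.
# Row 1: no response tokens at all.
lp = [[-5.0, -0.5, -1.5],
      [-2.0, -3.0, -4.0]]
mask = [[0, 1, 1],
        [0, 0, 0]]
sums, counts = masked_sum_and_count(lp, mask)
print(sums, counts)                   # [-2.0, 0.0] [2.0, 0.0]
print(per_token_mean(sums, counts))   # [-1.0, 0.0]
```

The empty row ends up with a sum of ``0.0`` and a mean of ``0.0``. Callers that average over a batch may still want to exclude such rows explicitly.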
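Because ``sum_logprob`` is a summed (not averaged) sequence log-probability, it can be fed directly into the standard DPO objective. That objective is ``-log sigmoid(beta * ((pi_chosen - ref_chosen) - (pi_rejected - ref_rejected)))``, evaluated on one preference pair. The sketch below assumes the four inputs are ``sum_logprob`` values from the policy and from a frozen reference model, computed on the chosen and rejected responses. ``dpo_loss``, ``log_sigmoid``, and the default ``beta=0.1`` are illustrative assumptions, not this repository's API. Its actual DPO code may differ, for example in ``beta`` or in length normalization.

```python
import math


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) for a Python float.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Standard DPO loss for one preference pair (hypothetical helper).

    All inputs are sequence-level summed log-probs over response tokens,
    i.e. the first value returned by sequence_logprobs for each sequence.
    beta=0.1 is an illustrative default, not taken from this codebase.
    """
    chosen_logratio = policy_chosen - ref_chosen
    rejected_logratio = policy_rejected - ref_rejected
    return -log_sigmoid(beta * (chosen_logratio - rejected_logratio))


# When policy and reference agree, the margin is 0 and the loss is log(2).
print(dpo_loss(-10.0, -12.0, -10.0, -12.0))   # ≈ 0.6931
# Raising the policy's log-prob of the chosen response lowers the loss.
print(dpo_loss(-8.0, -12.0, -10.0, -12.0))
```

Using sums rather than per-token means matches the original DPO formulation, which compares full-sequence likelihoods. A length-normalized variant would divide each input by its ``n_tokens.clamp(min=1)`` first.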
```python
def sequence_logprobs(
    model,
    sequences: torch.Tensor,
    response_mask: torch.Tensor,
    *,
    temperature: float = 1.0,
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sequence-level summed log-prob over response tokens (used by DPO/KTO/ORPO).

    Returns ``(sum_logprob, n_tokens)``, each of shape (B,). The per-token mean is
    ``sum_logprob / n_tokens.clamp(min=1)``.
    """
    # lp: per-token log-probs; mask: response-token mask with the same shape as lp.
    lp, mask = compute_logprobs(
        model, sequences, response_mask, temperature=temperature, requires_grad=requires_grad
    )
    # Cast the mask to lp's dtype so multiplying by it zeroes out non-response positions.
    m = mask.to(lp.dtype)
    # Reduce over the token dimension: summed response log-prob and response-token count per row.
    return (lp * m).sum(dim=-1), m.sum(dim=-1)
```

def sequence_entropy(model, sequences: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: