```python
def compute_logprobs(
    model,
    sequences: torch.Tensor,
    response_mask: torch.Tensor,
    *,
    temperature: float = 1.0,
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Teacher-forced recomputation of per-token log-probs of ``sequences`` under ``model``.

    Mirrors the model's training shift: ``logits[:, t]`` predicts ``sequences[:, t+1]``,
    so the returned tensors have length ``T-1`` and are aligned to target positions
    ``1..T-1``.

    Args:
        sequences: (B, T) token ids (prompt + response).
        response_mask: (B, T) bool over response positions (as from :class:`RolloutBatch`).
        temperature: divide logits by this before log-softmax, matching sampling.
        requires_grad: if False, runs under ``no_grad`` (use for ref / old-policy).

    Returns:
        (logprobs, mask) each (B, T-1). ``logprobs`` is the log-prob of the realized
        next token; ``mask`` is the response mask shifted to align with those targets.
    """
    ctx = torch.enable_grad() if requires_grad else torch.no_grad()
    with ctx:
        logits = _logits_from(model, sequences)[:, :-1, :]  # predict tokens 1..T-1
        logprobs_all = F.log_softmax(logits.float() / max(temperature, 1e-6), dim=-1)
        targets = sequences[:, 1:].unsqueeze(-1)
        logprobs = logprobs_all.gather(-1, targets).squeeze(-1)  # (B, T-1)
        mask = response_mask[:, 1:].to(torch.bool)
    return logprobs, mask


def sequence_logprobs(
```
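The shift described in the `compute_logprobs` docstring can be easier to follow on a toy example. The sketch below is not the repository's implementation: it uses plain Python lists in place of tensors, handles a single sequence with no batch dimension, and the names `log_softmax`, `shifted_logprobs`, and the toy data are hypothetical, chosen only for illustration. It shows that row `t` of the logits scores the token at position `t+1`, that the last row is dropped, and that the mask is shifted the same way.

```python
import math


def log_softmax(row, temperature=1.0):
    """Log-softmax of one row of logits after temperature scaling."""
    t = max(temperature, 1e-6)  # same guard against a zero temperature
    scaled = [x / t for x in row]
    m = max(scaled)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - lse for x in scaled]


def shifted_logprobs(logits, sequence, response_mask, temperature=1.0):
    """List-based, single-sequence analogue of the teacher-forced shift.

    logits: T rows of V floats; row t scores the token at position t+1.
    Returns (logprobs, mask), each of length T-1.
    """
    logprobs = []
    for t in range(len(sequence) - 1):  # the last row has no target
        row = log_softmax(logits[t], temperature)
        logprobs.append(row[sequence[t + 1]])  # pick the realized next token
    mask = [bool(m) for m in response_mask[1:]]  # align mask with targets
    return logprobs, mask


# Toy data (assumed): T=4, V=3; the first two tokens are the prompt.
sequence = [0, 2, 1, 1]
response_mask = [0, 0, 1, 1]
logits = [
    [0.0, 0.0, 0.0],  # scores position 1
    [1.0, 2.0, 3.0],  # scores position 2
    [0.5, 0.5, 0.5],  # scores position 3
    [9.0, 9.0, 9.0],  # unused: there is no position 4
]

logprobs, mask = shifted_logprobs(logits, sequence, response_mask)

# Sum only over response targets, as a masked sequence log-prob would.
response_logprob = sum(lp for lp, m in zip(logprobs, mask) if m)
```

Here both outputs have length `T-1 = 3`, and the first entry of `mask` is `False` because target position 1 still belongs to the prompt. Only the last two log-probs contribute to `response_logprob`.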