Repository: github.com/FareedKhan-dev/train-llm-from-scratch

Function compute_logprobs

Defined in src/post_training/rollout.py, lines 233–265.

Teacher-forced recomputation of per-token log-probs of ``sequences`` under ``model``. Mirrors the model's training shift: ``logits[:, t]`` predicts ``sequences[:, t+1]``, so both returned tensors have length ``T-1`` along the sequence axis and are aligned to target positions ``1..T-1``.

(
    model,
    sequences: torch.Tensor,
    response_mask: torch.Tensor,
    *,
    temperature: float = 1.0,
    requires_grad: bool = True,
)
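The alignment described above can be traced by hand. The sketch below is plain Python with no model: the token ids and mask are made-up values, and `shifted_alignment` is a hypothetical helper that only mirrors the slicing `logits[:, :-1]`, `sequences[:, 1:]` and `response_mask[:, 1:]` for a single unbatched sequence.

```python
def shifted_alignment(tokens, response_mask):
    """Pair each kept logit row with the target it scores (single sequence, no batch).

    Hypothetical helper: mirrors ``logits[:, :-1]`` vs ``sequences[:, 1:]`` and
    ``response_mask[:, 1:]`` from compute_logprobs.
    """
    rows = []
    for t in range(len(tokens) - 1):  # rows 0..T-2; the final logit row has no target
        target = t + 1                 # row t predicts the token at position t+1
        rows.append((t, target, tokens[target], bool(response_mask[target])))
    return rows


# Made-up example: 3 prompt tokens followed by 2 response tokens (T = 5).
tokens = [101, 7, 42, 9, 2]
response_mask = [0, 0, 0, 1, 1]
rows = shifted_alignment(tokens, response_mask)
for logit_pos, target_pos, tok, in_response in rows:
    print(f"logits[:, {logit_pos}] -> sequences[:, {target_pos}] = {tok}  counted={in_response}")
```

Four rows (T-1) come out, and only the last two count toward the response. The row computed at the last prompt position, `logits[:, 2]`, is the one that scores the first response token, which is why the mask is shifted by one rather than applied unshifted.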


# Module-level imports this excerpt relies on. ``_logits_from`` is defined
# elsewhere in the repository and is not shown here.
import torch
import torch.nn.functional as F


def compute_logprobs(
    model,
    sequences: torch.Tensor,
    response_mask: torch.Tensor,
    *,
    temperature: float = 1.0,
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Teacher-forced recomputation of per-token log-probs of ``sequences`` under ``model``.

    Mirrors the model's training shift: ``logits[:, t]`` predicts ``sequences[:, t+1]``,
    so the returned tensors have length ``T-1`` and are aligned to target positions
    ``1..T-1``.

    Args:
        model: policy or reference model; ``_logits_from(model, sequences)`` must
            return (B, T, V) logits.
        sequences: (B, T) token ids (prompt + response).
        response_mask: (B, T) bool over response positions (as from :class:`RolloutBatch`).
        temperature: divide logits by this before log-softmax, matching sampling.
        requires_grad: if False, runs under ``no_grad`` (use for ref / old-policy).

    Returns:
        (logprobs, mask) each (B, T-1). ``logprobs`` is the log-prob of the realized
        next token; ``mask`` is the response mask shifted to align with those targets.
    """
    ctx = torch.enable_grad() if requires_grad else torch.no_grad()
    with ctx:
        logits = _logits_from(model, sequences)[:, :-1, :]  # predict tokens 1..T-1
        # Upcast to float32 before log-softmax; the clamp avoids dividing by zero.
        logprobs_all = F.log_softmax(logits.float() / max(temperature, 1e-6), dim=-1)
        targets = sequences[:, 1:].unsqueeze(-1)
        logprobs = logprobs_all.gather(-1, targets).squeeze(-1)  # (B, T-1)
        mask = response_mask[:, 1:].to(torch.bool)
    return logprobs, mask
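A minimal end-to-end sketch, assuming `torch` is installed. `ToyLM` and the `_logits_from` stand-in are hypothetical (the repository's helper is not shown here), and the function body is copied from the listing so the block runs on its own. It checks the (B, T-1) shapes, reduces to a per-sequence response log-likelihood with a masked sum (an illustrative reduction, not necessarily what `sequence_logprobs` does), confirms that `requires_grad=False` yields a detached tensor, and compares the temperature-scaled result against `torch.distributions.Categorical`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def _logits_from(model, sequences):
    # Hypothetical stand-in: assumes model(ids) returns (B, T, V) logits directly.
    return model(sequences)


def compute_logprobs(model, sequences, response_mask, *, temperature=1.0, requires_grad=True):
    # Same steps as the listing above, repeated so this block is self-contained.
    ctx = torch.enable_grad() if requires_grad else torch.no_grad()
    with ctx:
        logits = _logits_from(model, sequences)[:, :-1, :]
        logprobs_all = F.log_softmax(logits.float() / max(temperature, 1e-6), dim=-1)
        logprobs = logprobs_all.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
        mask = response_mask[:, 1:].to(torch.bool)
    return logprobs, mask


class ToyLM(nn.Module):
    # Hypothetical tiny "language model": token embedding followed by a linear head.
    def __init__(self, vocab: int = 11, dim: int = 8):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return self.head(self.emb(ids))  # (B, T, V)


torch.manual_seed(0)
model = ToyLM()
# Made-up batch: row 0 has a 2-token response, row 1 a 1-token response plus padding.
sequences = torch.tensor([[1, 4, 5, 9, 2], [3, 3, 7, 0, 0]])
response_mask = torch.tensor([[0, 0, 0, 1, 1], [0, 0, 1, 0, 0]])

logprobs, mask = compute_logprobs(model, sequences, response_mask)
print(logprobs.shape, mask.shape)  # torch.Size([2, 4]) torch.Size([2, 4])

# Per-sequence response log-likelihood: sum log-probs over masked (response) targets only.
seq_logprob = (logprobs * mask.to(logprobs.dtype)).sum(dim=-1)
seq_logprob.sum().backward()  # gradients reach the model because requires_grad=True

# Reference / old-policy style call: no autograd graph is recorded.
ref_logprobs, _ = compute_logprobs(model, sequences, response_mask, requires_grad=False)
print(ref_logprobs.requires_grad)  # False

# Temperature scaling should match a Categorical distribution over logits / temperature.
tau = 0.7
lp_tau, _ = compute_logprobs(model, sequences, response_mask, temperature=tau, requires_grad=False)
with torch.no_grad():
    dist = torch.distributions.Categorical(logits=model(sequences)[:, :-1] / tau)
    expected = dist.log_prob(sequences[:, 1:])
print(torch.allclose(lp_tau, expected, atol=1e-6))
```

Row 1's padding targets (positions 3 and 4) still receive finite values in `logprobs`; it is the shifted `mask` that keeps them out of the sum.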

Callers (8 reported; 6 listed)

main (function)
main (function)
verify_grpo_optimizes (function)
verify_ppo_optimizes (function)
grpo_live.py (file)
sequence_logprobs (function)

Calls (1)

_logits_from (function)
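The listing only shows that `_logits_from(model, sequences)` must yield (B, T, V) logits, since its result is sliced with `[:, :-1, :]`; its real implementation is not reproduced on this page. The sketch below is a hypothetical adapter in that role (named `logits_from_sketch` to make clear it is not the repository's code), assuming the model may return a bare tensor, a tuple whose first element is the logits, or an object with a `.logits` attribute as Hugging Face model outputs do.

```python
import torch


def logits_from_sketch(model, sequences: torch.Tensor) -> torch.Tensor:
    """Hypothetical adapter in the role of ``_logits_from``; not the repository's code."""
    out = model(sequences)
    if isinstance(out, torch.Tensor):
        logits = out
    elif isinstance(out, (tuple, list)):
        logits = out[0]  # assumption: logits come first, e.g. (logits, loss)
    elif hasattr(out, "logits"):
        logits = out.logits  # object-style outputs
    else:
        raise TypeError(f"cannot extract logits from {type(out).__name__}")
    # compute_logprobs slices [:, :-1, :], so require one logit row per input token.
    if logits.dim() != 3 or tuple(logits.shape[:2]) != tuple(sequences.shape):
        raise ValueError(f"expected (B, T, V) logits, got {tuple(logits.shape)}")
    return logits


def tensor_model(ids: torch.Tensor) -> torch.Tensor:
    # Hypothetical model that returns logits directly.
    return torch.randn(ids.shape[0], ids.shape[1], 11)


def tuple_model(ids: torch.Tensor):
    # Hypothetical model that returns (logits, loss).
    return tensor_model(ids), None


ids = torch.zeros(2, 5, dtype=torch.long)
print(logits_from_sketch(tensor_model, ids).shape)  # torch.Size([2, 5, 11])
print(logits_from_sketch(tuple_model, ids).shape)   # torch.Size([2, 5, 11])
```

Validating the shape up front turns a silent misalignment (for example, logits that are missing the last position) into an immediate error rather than a wrong `gather`.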