Calculate perplexity as the exponentiated mean per-token negative log-likelihood (returns a single scalar).
(logits, output_ids)
def ppl(logits, output_ids, mask=None):
    """
    Calculate perplexity: exp of the mean per-token negative log-likelihood.

    Args:
        logits: Unnormalized scores whose last dimension indexes the
            vocabulary. No shift is applied here, so ``logits[..., i, :]`` is
            scored against ``output_ids[..., i]``. Align them beforehand if
            logits predict the *next* token.
        output_ids: Integer token ids. ``gather`` requires them to have the
            same leading dimensions as ``logits`` (i.e. ``logits`` without its
            last dimension).
        mask: Optional tensor broadcastable to ``output_ids``. Positions where
            it is zero (e.g. padding) are left out of the average; nonzero
            values act as weights.

    Returns:
        A single Python float: the perplexity over all (unmasked) tokens.

    Raises:
        ValueError: If ``mask`` gives every position zero weight.
    """
    import torch

    # Half-precision log-probs lose accuracy that exp() then amplifies, so
    # compute in at least float32 (float64 inputs are left untouched).
    logits = logits.to(torch.promote_types(logits.dtype, torch.float32))
    log_probs = logits.log_softmax(dim=-1)
    # Gather the target log-prob first and negate afterwards: this avoids
    # materializing a negated copy of the whole (..., vocab) tensor.
    token_nlls = -log_probs.gather(-1, output_ids.long().unsqueeze(-1)).squeeze(-1)

    if mask is None:
        return token_nlls.mean().exp().item()

    # expand_as so a broadcast mask is counted once per token position.
    weights = mask.to(token_nlls.dtype).expand_as(token_nlls)
    total_weight = weights.sum()
    if total_weight.item() == 0:
        raise ValueError("mask excludes every token; perplexity is undefined")
    return ((token_nlls * weights).sum() / total_weight).exp().item()
no test coverage detected