(probs, p)
| 9 | |
| 10 | |
def sample_top_p(probs, p):
    """Pick one index per row of ``probs`` via nucleus (top-p) sampling.

    Candidates are ranked by probability, highest first. A candidate stays
    eligible only while the probability mass ranked *above* it is at most
    ``p``. The surviving weights are rescaled to sum to 1, and one index is
    drawn from them with ``torch.multinomial``. The drawn position is then
    mapped back to its index in the unsorted ``probs``.

    Args:
        probs: Tensor of non-negative weights along the last dimension.
            Assumption (confirm with callers): it is 1-D or 2-D, since
            ``torch.multinomial`` only accepts those ranks.
        p: Nucleus threshold compared against the cumulative mass that
            precedes each candidate.

    Returns:
        Integer tensor of shape ``probs.shape[:-1] + (1,)``. Each entry is a
        sampled index into the last dimension of ``probs``.
    """
    sorted_probs, sorted_ids = probs.sort(dim=-1, descending=True)
    # Mass strictly before each ranked candidate. The top candidate always
    # sees 0, so it survives for any p >= 0.
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    # `sorted_probs` is a fresh tensor from sort(), so these in-place updates
    # never touch the caller's `probs`.
    sorted_probs[mass_before > p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    picked = torch.multinomial(sorted_probs, num_samples=1)
    # Map the position in sorted order back to the original index.
    return sorted_ids.gather(-1, picked)
| 20 | |
| 21 | |
| 22 | class LLaMA_Evaluator(Evaluator): |