(logits: torch.Tensor, topk: int, temperature: float)
| 75 | |
| 76 | |
def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
    """Sample a token from the top-k entries of temperature-scaled logits.

    The logits are divided by ``temperature``, every entry strictly below the
    k-th largest value along the last dimension is masked to ``-inf``, the
    result is normalised to probabilities over the last dimension, and the
    probabilities are handed to ``_multinomial_sample_one_no_sync``.

    Args:
        logits: Unnormalised scores; filtering and normalisation are applied
            along the last dimension.
        topk: Number of highest-scoring entries to keep. Values larger than
            the size of the last dimension are clamped to that size, which
            simply disables the filtering instead of raising inside
            ``torch.topk``.
        temperature: Divisor applied to the logits before filtering. Callers
            are expected to pass a positive value; it is not validated here.

    Returns:
        The value returned by ``_multinomial_sample_one_no_sync`` for the
        filtered probability distribution.
    """
    logits = logits / temperature

    # torch.topk raises if k is larger than the dimension it selects from,
    # so never ask for more entries than the last dimension holds.
    k = min(topk, logits.size(-1))

    filter_value: float = -float("Inf")
    # Smallest score that is still inside the top-k, kept with a trailing
    # singleton dimension so it broadcasts against ``logits``.
    kth_largest = torch.topk(logits, k)[0][..., -1, None]
    # Strict comparison: entries tied with the k-th value are all kept, so
    # more than k entries can survive when there are ties.
    indices_to_remove = logits < kth_largest
    scores_processed = logits.masked_fill(indices_to_remove, filter_value)

    # NOTE: softmax(log_softmax(x)) is mathematically the same as softmax(x);
    # the two-step form is kept so results stay numerically identical to the
    # previous implementation.
    scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
    probs = torch.nn.functional.softmax(scores_processed, dim=-1)

    sample_token = _multinomial_sample_one_no_sync(probs)
    return sample_token
| 88 | |
| 89 | |
| 90 | @dataclass |
no test coverage detected