(probs)
| 70 | |
| 71 | |
| 72 | def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization |
| 73 | q = torch.empty_like(probs).exponential_(1) |
| 74 | return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int) |
| 75 | |
| 76 | |
| 77 | def sample_topk(logits: torch.Tensor, topk: int, temperature: float): |