(logits: torch.Tensor, temperature: float = 0.0)
| 46 | |
| 47 | |
def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
    """Pick one vocabulary index per position from ``logits``.

    When ``temperature`` is below 1e-5 (this includes zero and negative
    values) the choice is greedy: ``argmax`` over the last dimension.
    Otherwise the logits are divided by ``temperature``, converted to
    probabilities with ``softmax``, and one index per position is drawn with
    ``torch.multinomial``.

    Args:
        logits: Scores whose last dimension indexes the vocabulary. Any number
            of leading dimensions is accepted, e.g. ``(vocab,)``,
            ``(batch, vocab)`` or ``(batch, seq_len, vocab)``. The tensor
            does not need to be contiguous.
        temperature: Sampling temperature. Values below 1e-5 select greedy
            decoding.

    Returns:
        An int64 tensor of shape ``logits.shape[:-1]`` holding the chosen
        indices. The shape is the same for the greedy and sampling paths.
    """
    if temperature < 1e-5:
        return torch.argmax(logits, dim=-1)
    vocab_size = logits.shape[-1]
    # Use reshape rather than view so that non-contiguous logits (transposed,
    # strided slices) do not raise. multinomial only accepts 1-D/2-D input,
    # so all leading dims are flattened into one row axis.
    flat = logits.reshape(-1, vocab_size) / temperature
    probs = torch.softmax(flat, dim=-1)
    # Restore the caller's leading dims so every input rank is supported,
    # matching the shape the argmax branch returns.
    return torch.multinomial(probs, num_samples=1).reshape(logits.shape[:-1])
| 55 | |
| 56 | |
| 57 | def _cuda_time() -> float: |