(logits, top_k=0, top_p=0.0, filter_value=-float('Inf'))
| 82 | |
| 83 | |
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits with top-k and/or nucleus (top-p) filtering.

    Filtering is applied along the last dimension, independently for every
    leading index, so any input of rank >= 1 is accepted (a single row of
    shape ``(1, vocab)`` behaves as before).

    Args:
        logits: Tensor of scores; candidates are laid out along the last
            dimension. NOTE: the tensor is modified in place.
        top_k: If > 0, keep only the ``top_k`` highest-scoring entries per
            row. Values larger than the last dimension are clamped to it.
        top_p: If > 0.0, keep the smallest set of highest-scoring entries
            whose cumulative softmax probability exceeds ``top_p``; the
            highest-scoring entry is always kept.
        filter_value: Value written into every filtered-out position.

    Returns:
        The same ``logits`` tensor (same shape), with filtered positions set
        to ``filter_value``.
    """
    # This function has been mostly taken from huggingface conversational ai code at
    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    if top_k > 0:
        # torch.topk raises when k exceeds the dimension size; clamp instead.
        top_k = min(top_k, logits.size(-1))
        # Remove all tokens scoring below the k-th best token of each row.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to the original token positions.
        # Scattering along the last dim works per row, so no flattening to 1D
        # (and hence no batch-size-1 restriction) is needed.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits
| 110 | |
| 111 | |
| 112 | class BeamScorer(ABC): |
no test coverage detected