(logits, top_k=0, top_p=0.0, filter_value=-float('Inf'))
| 22 | |
| 23 | |
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask out unlikely tokens using top-k and/or nucleus (top-p) filtering.

    This function has been mostly taken from huggingface conversational ai
    code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    Filtering is applied independently along the last dimension, so any
    number of leading (batch) dimensions is supported.

    Args:
        logits: Tensor of unnormalized scores; the last dimension indexes the
            candidate tokens. It is modified in place.
        top_k: If > 0, keep only the ``top_k`` highest-scoring tokens per row
            (entries tied with the k-th score are kept as well). Values larger
            than the number of candidates are clamped.
        top_p: If > 0.0, keep the smallest set of highest-probability tokens
            whose cumulative softmax probability exceeds ``top_p``; at least
            one token per row is always kept.
        filter_value: Value written into every filtered-out position.

    Returns:
        ``logits`` itself, with filtered positions set to ``filter_value``.
    """
    if top_k > 0:
        # torch.topk raises if asked for more entries than the dimension holds.
        top_k = min(top_k, logits.size(-1))
        # Remove all tokens scoring below the last token of the top-k.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token that crosses the threshold
        # is kept too; the best token is therefore never removed.
        sorted_indices_to_remove[..., 1:] = \
            sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to the original token order,
        # row by row, so batched input needs no flattening.
        indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter_(
            -1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits
| 50 | |
| 51 | def get_batch(context_tokens, device, args): |
| 52 | tokens = context_tokens |
no test coverage detected