(logits, top_k=0, top_p=0.0, filter_value=-float('Inf'))
| 98 | return tokens, attention_mask, position_ids |
| 99 | |
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask out unlikely tokens with top-k and/or nucleus (top-p) filtering.

    Both filters operate along the last dimension of ``logits``, so any
    number of leading (e.g. batch) dimensions is accepted. ``logits`` is
    modified in place and also returned.

    Args:
        logits: Tensor of unnormalized scores; the last dimension indexes
            the candidate tokens.
        top_k: If > 0, keep only the ``top_k`` highest-scoring entries per
            row. Values larger than the last dimension keep every entry.
        top_p: If > 0.0, keep the smallest set of highest-scoring entries
            whose cumulative softmax probability exceeds ``top_p``; the
            single most likely entry is always kept.
        filter_value: Value written into every filtered-out position.

    Returns:
        The same ``logits`` tensor, with filtered positions set to
        ``filter_value``.
    """
    # This function has been mostly taken from huggingface conversational ai code at
    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    if top_k > 0:
        # Clamp k so that asking for more candidates than exist keeps
        # everything instead of making torch.topk raise.
        k = min(top_k, logits.size(-1))
        # Remove all tokens with a score below the k-th largest one.
        kth_largest = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        # Sort each row independently so batched input is supported.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token crossing the threshold
        # is kept as well, and never drop the most likely token.
        sorted_indices_to_remove[..., 1:] = \
            sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Apply the mask in sorted order, then scatter the values back to
        # their original positions; this writes into `logits` in place.
        sorted_logits = sorted_logits.masked_fill(
            sorted_indices_to_remove, filter_value)
        logits.scatter_(-1, sorted_indices, sorted_logits)

    return logits
| 126 | |
| 127 | |
| 128 | def generate_samples(model, tokenizer, args, device): |
no test coverage detected