This function is mostly taken from Hugging Face's conversational AI code at https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
(logits, top_k=0, top_p=0.0, filter_value=-float("Inf"))
| 85 | |
| 86 | |
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """Apply top-k and/or nucleus (top-p) filtering to ``logits`` in place.

    This function has been mostly taken from huggingface conversational
    ai code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    Filtering always acts along the last dimension. Any leading dimensions
    (e.g. a batch dimension) are filtered independently.

    Args:
        logits: tensor of unnormalised scores. The last dimension is the one
            being filtered (presumably the vocabulary -- confirm with caller).
        top_k: if > 0, keep only the ``top_k`` highest-scoring entries along
            the last dimension. Values larger than that dimension keep every
            entry instead of raising.
        top_p: if > 0.0, keep the smallest prefix of highest-scoring entries
            whose cumulative softmax probability exceeds ``top_p``. The most
            likely entry is always kept.
        filter_value: value written into every filtered-out position.

    Returns:
        The same ``logits`` tensor, modified in place.
    """
    if top_k > 0:
        # torch.topk raises when k exceeds the dimension size, so clamp it.
        top_k = min(top_k, logits.size(-1))
        # Remove every entry scoring below the k-th largest along the last dim.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token that crosses the threshold
        # is kept as well, and never remove the single most likely token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to original positions. Because
        # sorted_indices is a permutation, every slot is overwritten. Unlike
        # a per-row loop, this works for 1-D, 2-D or higher-rank input.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits
| 115 | |
| 116 | |
| 117 | def pad_batch(batch, pad_id, seq_length): |
no test coverage detected