github.com/THUDM/GLM

Function top_k_logits

generation_utils.py:84–109
(logits, top_k=0, top_p=0.0, filter_value=-float('Inf'))
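This is a restatement of what the code below does, not documentation published with GLM. For a logit vector $z \in \mathbb{R}^V$, write $k$ for `top_k`, $p$ for `top_p` and $f$ for `filter_value`. The two stages then act as follows:

```latex
% Top-k stage (applied when k > 0); \tau_k is the k-th largest entry of z
z_i \leftarrow
\begin{cases}
  z_i, & z_i \ge \tau_k \\
  f,   & z_i < \tau_k
\end{cases}

% Top-p stage (applied when p > 0), on the already top-k-filtered z
\pi = \operatorname{softmax}(z), \qquad
z_{\sigma(1)} \ge z_{\sigma(2)} \ge \cdots \ge z_{\sigma(V)}, \qquad
c_j = \sum_{m=1}^{j} \pi_{\sigma(m)}

z_{\sigma(j)} \leftarrow f \quad \text{iff} \quad j \ge 2 \ \text{and}\ c_{j-1} > p
```

The top-k test uses a strict `<`, so every token tied with $\tau_k$ is kept and more than $k$ tokens can survive. The top-p removal mask is shifted right by one position, so that stage never removes the most probable token, even when $\pi_{\sigma(1)} > p$.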

```python
import torch
import torch.nn.functional as F


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # This function has been mostly taken from huggingface conversational ai code at
    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
    # Note: both filters write filter_value into `logits` in place.

    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # convert to 1D (assumes logits has shape [1, vocab_size], i.e. batch size 1)
        logits = logits.view(logits.size()[1]).contiguous()
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
        # going back to 2D: shape [1, vocab_size]
        logits = logits.view(1, -1).contiguous()

    return logits
```
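If PyTorch is not available, the two-stage logic can still be reproduced on a single row of logits held in a plain list. The stdlib-only sketch below does this. The name `filter_logits` is hypothetical and not part of GLM. It is a re-implementation written for illustration, and it may break ties in a different order from `torch.sort`.

```python
import math
from typing import List


def filter_logits(logits: List[float], top_k: int = 0, top_p: float = 0.0,
                  filter_value: float = -math.inf) -> List[float]:
    """Hypothetical pure-Python mirror of top_k_logits for a single row.

    Returns a new list; unlike the torch version, the input is not modified.
    """
    out = list(logits)

    if top_k > 0:
        # k-th largest value; entries strictly below it are filtered (ties kept).
        kth = sorted(out, reverse=True)[top_k - 1]
        out = [x if x >= kth else filter_value for x in out]

    if top_p > 0.0:
        # Softmax over the (possibly top-k filtered) logits, max-shifted for stability.
        m = max(out)
        exps = [math.exp(x - m) for x in out]
        total = sum(exps)
        probs = [e / total for e in exps]

        # Indices ordered by descending logit, like torch.sort(descending=True).
        order = sorted(range(len(out)), key=lambda i: out[i], reverse=True)

        cumulative = 0.0
        for rank, idx in enumerate(order):
            # Remove a token when the mass of the tokens ranked *before* it already
            # exceeds top_p; this is the "shift right" step, so rank 0 is always kept.
            if rank > 0 and cumulative > top_p:
                out[idx] = filter_value
            cumulative += probs[idx]

    return out
```

For example, `filter_logits([2.0, 1.0, 0.0, -1.0], top_p=0.9)` keeps the first three entries and drops the last, because the cumulative probability of the three tokens ranked above it is about 0.968, which exceeds 0.9. The sketch returns a new list and works on one row at a time. For a batch, call it once per row. This avoids both the in-place mutation and the batch-size-1 `view()` assumption of the torch version.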

Callers (2)

sample_sequence · function · 0.90
evaluate · method · 0.90
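One of the callers is `sample_sequence`, but its body is not shown on this page. After filtering, a common next step is to convert the surviving logits into probabilities and draw one token. The sketch below shows that step under this assumption; it is not GLM's code, and the helper `sample_from_filtered` is hypothetical.

```python
import math
import random
from typing import List


def sample_from_filtered(filtered_logits: List[float], rng: random.Random) -> int:
    """Hypothetical helper, not GLM's sample_sequence.

    Draws one token id from logits that were already top-k/top-p filtered.
    """
    m = max(filtered_logits)
    # math.exp(-inf) == 0.0, so filtered tokens get zero weight.
    weights = [math.exp(x - m) for x in filtered_logits]
    # random.choices normalizes the weights itself.
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]
```

In PyTorch this step is usually written as `torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)`. With the default `filter_value` of `-inf`, filtered entries get probability 0 in both versions, so only tokens that survived top-k/top-p filtering can be drawn.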

Calls (1)

cumsum · method · 0.80

Tested by

no test coverage detected