Function top_k_logits

codegeex/torch/inference.py:87–114 (github.com/zai-org/CodeGeeX)

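Applies top-k and/or top-p (nucleus) filtering to logits by overwriting the logits of excluded tokens with filter_value. Each filter runs only when its parameter is set (top_k > 0, top_p > 0.0), so with the defaults the logits are returned unchanged. The logits tensor is modified in place and also returned. Per its docstring, the function is mostly taken from Hugging Face's conversational AI code at https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313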

(logits, top_k=0, top_p=0.0, filter_value=-float("Inf"))
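Both filters exclude a token by writing filter_value (default -inf) into its logit. The effect of that default: after a softmax, exp(-inf) is 0, so excluded tokens get probability exactly 0 while the surviving probabilities still sum to 1. A minimal stdlib-only sketch over a plain list instead of a tensor (the softmax helper is hypothetical, written only for illustration):

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats.

    Entries equal to -inf get probability 0.0, because exp(-inf - m) == 0.0
    for any finite maximum m. Assumes at least one logit is finite; if every
    entry is -inf, the result is NaN.
    """
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


masked = [2.0, 1.0, -float("inf"), 0.5]
probs = softmax(masked)
print(probs[2])  # 0.0
```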


```python
import torch
import torch.nn.functional as F


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """Top-k / top-p (nucleus) filtering of logits.

    Mostly taken from the Hugging Face conversational AI code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
    """
    # Note: `logits` is modified in place and also returned. The top-p loop
    # below iterates over the first dimension, so it expects a
    # [batch_size, vocab_size] tensor.

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort logits in descending order and compute cumulative probabilities
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask one position to the right so the first token
        # that crosses the threshold is also kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Map the sorted-order mask back to vocabulary indices, one row at a time
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits
```
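To make the masking rules concrete without PyTorch, here is a stdlib-only sketch that applies the same top-k and top-p logic to a single row given as a Python list. It is an illustration under that simplification (filter_row and softmax are hypothetical helper names), not code from CodeGeeX, and unlike the original it returns a filtered copy instead of editing its input in place.

```python
import math

NEG_INF = -float("inf")


def softmax(row):
    """Stable softmax; -inf entries get probability 0.0 (needs one finite entry)."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def filter_row(row, top_k=0, top_p=0.0, filter_value=NEG_INF):
    """Return a filtered copy of one row of logits."""
    row = list(row)
    if top_k > 0:
        # k-th largest logit; as in the torch code (`<`, not `<=`),
        # tokens tied with it survive
        kth = sorted(row, reverse=True)[top_k - 1]
        row = [filter_value if x < kth else x for x in row]
    if top_p > 0.0:
        # Indices ordered by logit, descending, plus their softmax probabilities
        order = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        probs = softmax([row[i] for i in order])
        cumulative = 0.0
        remove = []
        for p in probs:
            cumulative += p
            remove.append(cumulative > top_p)
        # Shift right by one so the token that first crosses top_p is kept
        remove = [False] + remove[:-1]
        for i, r in zip(order, remove):
            if r:
                row[i] = filter_value
    return row


row = [2.0, 1.0, 0.5, -1.0]
print(filter_row(row, top_k=2))    # [2.0, 1.0, -inf, -inf]
print(filter_row(row, top_p=0.9))  # [2.0, 1.0, 0.5, -inf]
```

For this row the softmax probabilities are roughly 0.609, 0.224, 0.136 and 0.030, giving cumulative sums of about 0.609, 0.834 and 0.970. With top_p=0.9 the third token is the first to cross the threshold and is kept because of the shift, so three tokens survive. With top_p=0.5 only the most likely token survives, which is why the shift guarantees at least one token is never masked.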

Callers (1): sample_sequence_batch (function)

Calls (2): cumsum (method), size (method)

Tested by: no test coverage detected
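The filtered logits are meant to be turned into a distribution and sampled from. The caller sample_sequence_batch is not shown on this page, so the following is only a generic stdlib-only sketch of that step under that assumption, not its actual code; sample_token is a hypothetical name. Masked entries get weight exp(-inf) = 0, so random.choices never draws them.

```python
import math
import random


def sample_token(filtered_logits, rng=random):
    """Draw one token index from logits that have already been filtered.

    Weights are exp(logit - max); masked (-inf) entries get weight 0.0 and are
    never drawn. Assumes at least one logit is finite.
    """
    m = max(filtered_logits)
    weights = [math.exp(x - m) for x in filtered_logits]
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


rng = random.Random(0)
logits = [2.0, 1.0, -float("inf"), -float("inf")]
print(sample_token(logits, rng))  # always 0 or 1
```

Passing an explicit random.Random instance keeps draws reproducible in tests; with these logits token 0 is drawn about 73% of the time and token 1 about 27%.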