Repository: github.com/zai-org/CogView

Function top_k_logits

generation/sampling.py:24–49
(logits, top_k=0, top_p=0.0, filter_value=-float('Inf'))


```python
import torch
import torch.nn.functional as F


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # This function has been mostly taken from the Hugging Face conversational AI code at
    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k.
        # Note: this writes into the caller's `logits` tensor in place.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Convert to 1D; this assumes logits of shape (1, vocab_size), i.e. batch size 1.
        logits = logits.view(logits.size()[1]).contiguous()
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token above the threshold is also kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        # Always keep the most likely token
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
        # Go back to 2D: shape (1, vocab_size)
        logits = logits.view(1, -1).contiguous()

    return logits
```
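The two filters are easiest to follow on a toy distribution. Below is a dependency-free sketch (the name `filter_logits` is hypothetical, not part of CogView) that mirrors the same steps on a plain Python list for a single sequence; unlike `top_k_logits`, it returns a new list instead of mutating its input.

```python
import math


def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-math.inf):
    """Hypothetical pure-Python mirror of top_k_logits for one sequence."""
    out = list(logits)  # work on a copy

    if top_k > 0:
        # Threshold is the k-th largest logit; only strictly smaller values are
        # masked, so ties with the k-th value survive (as with `logits < kth`).
        kth = sorted(out, reverse=True)[top_k - 1]
        out = [x if x >= kth else filter_value for x in out]

    if top_p > 0.0:
        # Token indices sorted by logit, most likely first.
        order = sorted(range(len(out)), key=lambda i: out[i], reverse=True)
        # Numerically stable softmax over the (possibly top-k-masked) logits.
        m = max(out)
        exps = [math.exp(out[i] - m) for i in order]
        total = sum(exps)
        cumulative = 0.0
        remove = []
        for e in exps:
            cumulative += e / total
            remove.append(cumulative > top_p)
        # Shift right by one: keeps the first token that crosses top_p
        # and always keeps the most likely token.
        remove = [False] + remove[:-1]
        for i, r in zip(order, remove):
            if r:
                out[i] = filter_value

    return out
```

With logits `[2.0, 1.0, 0.0, -1.0]` the softmax probabilities are roughly 0.64, 0.24, 0.09 and 0.03, so `top_p=0.7` keeps the first two tokens: the cumulative mass first exceeds 0.7 at the second token, and the right shift keeps that token instead of dropping it.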
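In the top-p branch, `logits.view(logits.size()[1])` only succeeds when the tensor holds a single row, so the nucleus filter assumes a batch size of 1. If batched sampling is needed, one possible alternative is sketched below (not CogView's code; the name `top_k_top_p_filter_batched` is hypothetical). It assumes logits of shape (batch, vocab), works on a copy, and maps the sorted-order mask back to vocabulary order with `Tensor.scatter`.

```python
import torch
import torch.nn.functional as F


def top_k_top_p_filter_batched(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    """Hypothetical batched variant; logits has shape (batch, vocab)."""
    logits = logits.clone()  # leave the caller's tensor unchanged

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # guard against k > vocab size
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, filter_value)

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing top_p is kept; always keep the top token.
        sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
        sorted_to_remove[..., 0] = False

        # sorted_indices is a per-row permutation, so scatter writes every position:
        # to_remove[b, sorted_indices[b, j]] = sorted_to_remove[b, j]
        to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
        logits = logits.masked_fill(to_remove, filter_value)

    return logits
```

A token can then be drawn per row with `torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)`; masked entries get probability 0 after the softmax.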

Callers 1

filling_sequence · Function · 0.85

Calls 1

cumsum · Method · 0.80

Tested by

no test coverage detected