Repository: github.com/SesameAILabs/csm

Function sample_topk

Defined in models.py, lines 77–87
Signature: sample_topk(logits: torch.Tensor, topk: int, temperature: float)
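Reading the source, sample_topk draws one index per row from the distribution below. The notation is introduced here for illustration: z is one row of logits, T is `temperature`, k is `topk`, V is the number of entries in the row, and τ is the k-th largest scaled logit in that row.

$$
s_i = \frac{z_i}{T}, \qquad
p_i = \frac{\exp(s_i)\,\mathbf{1}[s_i \ge \tau]}{\sum_{j=1}^{V} \exp(s_j)\,\mathbf{1}[s_j \ge \tau]}, \qquad i = 1, \dots, V
$$

Entries tied with τ also pass the test, so the support can contain more than k entries.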


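```python
import torch


def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
    # Temperature scaling: values below 1 sharpen the distribution, values above 1 flatten it.
    logits = logits / temperature

    filter_value: float = -float("Inf")
    # Mask every logit strictly below the k-th largest value in its row
    # ([..., -1, None] keeps a trailing dim of size 1 so the comparison broadcasts).
    indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
    scores_processed = logits.masked_fill(indices_to_remove, filter_value)
    # log_softmax followed by softmax yields the same probabilities as softmax alone.
    scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
    probs = torch.nn.functional.softmax(scores_processed, dim=-1)

    # Draw one token index per row; _multinomial_sample_one_no_sync is not shown on this page.
    sample_token = _multinomial_sample_one_no_sync(probs)
    return sample_token
```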
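The probability computation in sample_topk can be run on its own to see the masking at work. The sketch below copies the steps before the sampling call into `topk_probs`, a hypothetical helper name introduced here (not part of models.py), and applies them to one hand-picked row of logits that contains a tie.

```python
import torch


def topk_probs(logits: torch.Tensor, topk: int, temperature: float) -> torch.Tensor:
    # Hypothetical helper: the steps sample_topk performs before drawing a token.
    logits = logits / temperature
    # k-th largest scaled logit per row, kept as a size-1 trailing dim for broadcasting.
    kth = torch.topk(logits, topk)[0][..., -1, None]
    # Only entries strictly below the k-th value are masked; ties with it survive.
    masked = logits.masked_fill(logits < kth, -float("Inf"))
    scores = torch.nn.functional.log_softmax(masked, dim=-1)
    return torch.nn.functional.softmax(scores, dim=-1)


logits = torch.tensor([[1.0, 3.0, 2.0, 3.0, 0.5]])

# topk=2: the two largest logits (indices 1 and 3) each get probability 0.5, the rest 0.
print(topk_probs(logits, topk=2, temperature=1.0))

# topk=1: the threshold is 3.0 and both entries equal to it survive, so again 0.5 / 0.5.
print(topk_probs(logits, topk=1, temperature=1.0))

# The log_softmax step does not change the result: softmax alone gives the same probabilities.
masked = logits.masked_fill(logits < 2.0, float("-inf"))
assert torch.allclose(
    torch.softmax(torch.log_softmax(masked, dim=-1), dim=-1),
    torch.softmax(masked, dim=-1),
)
```

The tie case is worth noting: `topk` is a lower bound on how many entries survive, not an exact count, because the mask removes only values strictly below the k-th largest one.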

Callers: 1 (generate_frame, method · 0.85)

Calls: 1

Tested by: no test coverage detected
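Since no test coverage is detected, here is a hedged sketch of property tests one might write. It cannot import models.py in isolation, so it inlines a copy of sample_topk as `sample_topk_copy` (a hypothetical name) and swaps in `torch.multinomial` for `_multinomial_sample_one_no_sync`, whose body is not shown on this page. The tests therefore check the masking and renormalization logic, not that helper.

```python
import torch


def _stand_in_sample_one(probs: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for _multinomial_sample_one_no_sync (body not shown on this page).
    # torch.multinomial draws one index per row and never picks a zero-probability entry.
    return torch.multinomial(probs, num_samples=1)


def sample_topk_copy(logits: torch.Tensor, topk: int, temperature: float) -> torch.Tensor:
    # Copy of the function under test, with the stand-in sampler swapped in.
    logits = logits / temperature
    indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
    scores = logits.masked_fill(indices_to_remove, -float("Inf"))
    scores = torch.nn.functional.log_softmax(scores, dim=-1)
    probs = torch.nn.functional.softmax(scores, dim=-1)
    return _stand_in_sample_one(probs)


def test_only_topk_indices_are_sampled() -> None:
    torch.manual_seed(0)
    logits = torch.randn(4, 50)  # batch of 4 rows over a 50-entry vocabulary
    allowed = torch.topk(logits, 5).indices  # shape (4, 5)
    for _ in range(200):
        token = sample_topk_copy(logits, topk=5, temperature=0.8)
        assert token.shape == (4, 1)
        # Each sampled index must be one of its row's 5 largest logits.
        assert (token == allowed).any(dim=-1).all()


def test_frequencies_match_topk_softmax() -> None:
    torch.manual_seed(0)
    n = 20_000
    logits = torch.tensor([[2.0, 1.0, 0.0, -1.0]]).expand(n, -1)  # n identical rows
    draws = sample_topk_copy(logits, topk=2, temperature=1.0)
    freq = torch.bincount(draws.flatten(), minlength=4).float() / n
    # With topk=2, the mass is the softmax renormalized over the two largest logits.
    expected = torch.softmax(torch.tensor([2.0, 1.0]), dim=0)
    assert freq[2:].sum() == 0
    assert torch.allclose(freq[:2], expected, atol=0.02)


if __name__ == "__main__":
    test_only_topk_indices_are_sampled()
    test_frequencies_match_topk_softmax()
    print("ok")
```

Inside the repository itself, the tests would import sample_topk from models rather than copying it.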