Repository: github.com/z-lab/dflash

Function sample

dflash/model.py:48–54
(logits: torch.Tensor, temperature: float = 0.0)
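The two parameters determine a simple decision rule. Let z be the logits of shape (B, S, V), meaning batch, sequence length, and vocabulary size, and let T be the temperature. The function body then picks one token id per position as follows (the notation is ours, not the repository's):

$$
\hat{y}_{b,s} =
\begin{cases}
\arg\max_{v} \, z_{b,s,v}, & T < 10^{-5} \\
v \sim \operatorname{Categorical}(p_{b,s}), & \text{otherwise}
\end{cases}
\qquad
p_{b,s,v} = \frac{\exp(z_{b,s,v}/T)}{\sum_{u=1}^{V} \exp(z_{b,s,u}/T)}
$$

Only the sampling branch unpacks `logits.shape` into three values, so it requires a 3-D tensor. The greedy branch accepts logits of any rank and reduces over the last dimension.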

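import torch


def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
    # Near-zero temperature: greedy decoding, return the highest-scoring token id.
    if temperature < 1e-5:
        return torch.argmax(logits, dim=-1)
    # Sampling path expects logits of shape (batch, seq_len, vocab_size).
    bsz, seq_len, vocab_size = logits.shape
    # One row per position, scaled by 1/T (T > 1 flattens, T < 1 sharpens).
    logits = logits.view(-1, vocab_size) / temperature
    probs = torch.softmax(logits, dim=-1)
    # Draw one token id per row, then restore the (batch, seq_len) shape.
    return torch.multinomial(probs, num_samples=1).view(bsz, seq_len)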
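The two branches of `sample` can be mirrored for a single row of logits without PyTorch, which makes the effect of the temperature easy to inspect. This is a minimal, dependency-free sketch for illustration only. `softmax_with_temperature` and `sample_row` are hypothetical helper names that are not part of dflash. Python's `random.choices` stands in for `torch.multinomial`. Subtracting the row maximum before exponentiating is a standard trick to avoid overflow.

```python
import math
import random
from typing import List, Optional


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Temperature-scaled softmax over one row of logits (hypothetical helper)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # shift by the max so math.exp() cannot overflow
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_row(logits: List[float], temperature: float = 0.0,
               rng: Optional[random.Random] = None) -> int:
    """Pick one token id from one row of logits, mirroring sample()'s two branches."""
    if temperature < 1e-5:
        # Greedy branch: index of the largest logit (first index among ties).
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    probs = softmax_with_temperature(logits, temperature)
    # Stochastic branch: draw one index with probability probs[i].
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]
```

For the logits `[2.0, 1.0, 0.1]`, `softmax_with_temperature(..., 1.0)` gives roughly `[0.659, 0.242, 0.099]`. The probability of the top token rises to about 0.864 at T = 0.5 and falls to about 0.502 at T = 2.0. As T approaches the 1e-5 cutoff, the distribution collapses toward the greedy choice.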

Callers (1)

dflash_generate (function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected