MCPcopy
hub / github.com/FoundationVision/LlamaGen / _multinomial

Function _multinomial

autoregressive/serve/sampler.py:383–407  ·  view source on GitHub ↗
(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
    generators: Optional[List[torch.Generator]] = None,
)

Source from the content-addressed store, hash-verified

381# probs will be modified in place, but this is fine, as we pass
382# in a copy already.
383def _multinomial(
384 probs: torch.Tensor,
385 num_samples: int,
386 seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
387 generators: Optional[List[torch.Generator]] = None,
388) -> torch.Tensor:
389 if num_samples > 1:
390 # This is equivalent to torch.repeat_interleaved (which also
391 # forces a GPU<->CPU sync).
392 # This allows us to do sampling with replacement by creating
393 # num_samples copies of each row in the tensor, and then
394 # batch sampling the resulting tensor.
395 probs = probs[:, None, :].expand(probs.shape[0], num_samples,
396 probs.shape[1]).contiguous().view(
397 -1, probs.shape[1])
398 q = torch.empty_like(probs)
399 if seq_groups is None:
400 q.exponential_()
401 else:
402 sample_idx = 0
403 for (seq_ids, _), generator in zip(seq_groups, generators):
404 next_sample_idx = sample_idx + len(seq_ids) * num_samples
405 q[sample_idx:next_sample_idx].exponential_(generator=generator)
406 sample_idx = next_sample_idx
407 return probs.div_(q).argmax(dim=1).view(-1, num_samples)
408
409
410def _sample_with_torch(

Callers 1

_sample_with_torch — Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected