github.com/NVIDIA/TensorRT-LLM / categorical_sample

Function categorical_sample

tensorrt_llm/functional.py:1436–1465

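This is a sampling operation equivalent to torch.distributions.Categorical.sample(): given a tensor of probabilities, it samples an index of that tensor. The probabilities do not need to be normalized, because the function divides them by their sum along the last dimension; for a multi-dimensional input, one index is drawn along the last dimension for each position in the leading dimensions. See: https://pytorch.org/docs/stable/distributions.html#torch.distributions.categorical.Categorical.sample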
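Sampling a categorical distribution this way is inverse-CDF (inverse transform) sampling: normalize the weights, form their running sum, and pick the first index whose cumulative probability reaches a uniform draw u in [0, 1). A minimal pure-Python sketch for a single distribution follows; `sample_index` is a hypothetical helper for illustration, not part of TensorRT-LLM, and it accepts `u` explicitly so results can be reproduced.

```python
import random
from itertools import accumulate
from typing import Optional, Sequence


def sample_index(weights: Sequence[float], u: Optional[float] = None) -> int:
    """Draw one index from non-negative, possibly unnormalized weights.

    Hypothetical illustration of inverse-CDF sampling; not a TensorRT-LLM API.
    """
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    if u is None:
        u = random.random()  # uniform draw in [0, 1)
    # Normalize on the fly and walk the running sum (the discrete CDF).
    for i, cdf in enumerate(accumulate(w / total for w in weights)):
        if cdf >= u:
            return i  # first index whose cumulative probability reaches u
    # Rounding can leave the last CDF value just below u; this sketch then
    # returns the last index with non-zero weight (a choice made here).
    return max(i for i, w in enumerate(weights) if w > 0)
```

For example, `sample_index([1.0, 3.0], u=0.2)` returns 0 and `sample_index([2, 6], u=0.5)` returns 1. Because the weights are normalized first, scaling them (for instance `[1, 3]` versus `[0.25, 0.75]`) does not change the result for a given `u`.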

(probs: Tensor, rand_data: Tensor = None) -> Tensor

Source from the content-addressed store, hash-verified

def categorical_sample(probs: Tensor, rand_data: Tensor = None) -> Tensor:
    '''
    This is a sampling operation and an equivalent of torch.distributions.Categorical.sample()
    i.e. given a probability distribution tensor, it samples an index of that tensor.
    See: https://pytorch.org/docs/stable/distributions.html#torch.distributions.categorical.Categorical.sample
    NOTE: This assumes that the given probabilities are **not** normalized.

    Parameters:
        probs: Tensor
            A 1-D floating point tensor representing the probability distributions.
        rand_data: Tensor (optional)
            A random tensor of same shape as `probs` tensor.
            If not provided, this function will add a rand() op to generate it and use for sampling.
    Returns:
        A tensor containing a single index of the `probs` tensor representing the sample.
    '''
    probs = probs / sum(probs, dim=-1, keepdim=True)
    rand_shape = []
    assert probs.ndim() > 0
    for i in range(probs.ndim() - 1):
        rand_shape.append(shape(probs, i))
    rand_shape = concat(rand_shape)
    if rand_data is None:
        rand_data = rand(rand_shape, low=0, high=1, dtype=probs.dtype)
    assert rand_shape == shape(rand_data)
    rand_data = expand(unsqueeze(rand_data, -1), shape(probs))
    cum_probs = cumsum(probs, dim=-1)
    cmp = cast(cum_probs >= rand_data, probs.dtype)
    samples = argmax(cmp, dim=-1)
    return samples
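In the code, `rand_shape` collects every dimension of `probs` except the last, and the `unsqueeze`/`expand` pair broadcasts one uniform draw across each row, so `rand_data` is expected to have the shape of `probs` without its last dimension. The following NumPy sketch mirrors those graph ops one-to-one so the shapes and results can be checked outside a TensorRT network; `categorical_sample_ref` is a hypothetical reference helper, not a TensorRT-LLM function, and NumPy's `argmax` stands in for the TensorRT-LLM `argmax` op.

```python
from typing import Optional

import numpy as np


def categorical_sample_ref(probs: np.ndarray,
                           rand_data: Optional[np.ndarray] = None,
                           rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """NumPy mirror of categorical_sample's graph ops (hypothetical reference)."""
    probs = np.asarray(probs)
    if probs.ndim == 0:  # stands in for `assert probs.ndim() > 0`
        raise ValueError("probs must have at least one dimension")
    # probs / sum(probs, dim=-1, keepdim=True): inputs need not be normalized.
    probs = probs / probs.sum(axis=-1, keepdims=True)
    # rand_shape: every dimension of probs except the last one.
    rand_shape = probs.shape[:-1]
    if rand_data is None:
        rng = rng if rng is not None else np.random.default_rng()
        rand_data = rng.random(size=rand_shape)  # uniform draws in [0, 1)
    rand_data = np.asarray(rand_data, dtype=probs.dtype)
    if rand_data.shape != rand_shape:  # stands in for `assert rand_shape == shape(rand_data)`
        raise ValueError(f"rand_data must have shape {rand_shape}, got {rand_data.shape}")
    # expand(unsqueeze(rand_data, -1), shape(probs)): reuse one draw across each row.
    rand_data = np.broadcast_to(np.expand_dims(rand_data, -1), probs.shape)
    cum_probs = np.cumsum(probs, axis=-1)
    # cast(cum_probs >= rand_data, probs.dtype): 1.0 where the CDF has reached u.
    cmp = (cum_probs >= rand_data).astype(probs.dtype)
    # argmax picks the first maximum, i.e. the first index where cmp is 1.
    return np.argmax(cmp, axis=-1)
```

For `probs = [[1, 1, 2], [3, 1, 0]]` with `rand_data = [0.3, 0.5]`, the sketch returns `[1, 0]`. Passing explicit `rand_data` makes the output deterministic, which is convenient for unit tests; a `rand_data` with the full shape of `probs` is rejected. In this sketch, if rounding left the final cumulative value below u, every entry of `cmp` would be 0 and `np.argmax` would return index 0.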

Callers 1

_fwd_helper (method) · 0.90

Calls 11

sum (function) · 0.85
concat (function) · 0.85
rand (function) · 0.85
expand (function) · 0.85
unsqueeze (function) · 0.85
cumsum (function) · 0.85
cast (function) · 0.85
argmax (function) · 0.85
shape (function) · 0.70
ndim (method) · 0.45
append (method) · 0.45

Tested by

no test coverage detected