This is a sampling operation equivalent to torch.distributions.Categorical.sample(), i.e. given a tensor of probabilities, it samples an index along its last dimension. See: https://pytorch.org/docs/stable/distributions.html#torch.distributions.categorical.Categorical.sample. NOTE: The probabilities do not need to be normalized; they are normalized internally.
(probs: Tensor, rand_data: Tensor = None)
| 1434 | |
| 1435 | |
def categorical_sample(probs: Tensor, rand_data: Tensor = None) -> Tensor:
    '''
    This is a sampling operation and an equivalent of torch.distributions.Categorical.sample()
    i.e. given a tensor of probabilities, it samples an index along the last dimension
    for every distribution described by the leading dimensions.
    See: https://pytorch.org/docs/stable/distributions.html#torch.distributions.categorical.Categorical.sample
    NOTE: The given probabilities do **not** need to be normalized; they are divided by
    their sum along the last dimension before sampling.

    Sampling is done by inverse-CDF lookup: the cumulative sum of the normalized
    probabilities is compared against one random value per distribution, and the
    position where the cumulative sum first reaches that value is selected.

    Parameters:
        probs: Tensor
            A floating point tensor of rank >= 1 whose last dimension holds the
            (possibly unnormalized) probabilities of each category.
            NOTE(review): for 1-D `probs` the shape list passed to `concat` is
            empty; confirm `concat([])` is supported before relying on 1-D input.
        rand_data: Tensor (optional)
            Random values between 0 and 1, one per distribution, i.e. shaped like
            `probs` without its last dimension (rank `probs.ndim() - 1`).
            If not provided, this function will add a rand() op to generate it and use for sampling.
    Returns:
        A tensor holding, for each distribution, the sampled index along the last
        dimension of `probs` (computed with argmax over dim=-1).
    '''
    assert probs.ndim() > 0
    probs = probs / sum(probs, dim=-1, keepdim=True)

    # One random value per distribution: the shape of `probs` without its
    # last (category) dimension, assembled into a single shape tensor.
    rand_shape = concat([shape(probs, i) for i in range(probs.ndim() - 1)])

    if rand_data is None:
        rand_data = rand(rand_shape, low=0, high=1, dtype=probs.dtype)
    else:
        # ndim() returns plain ints, so this check is enforced at build time.
        assert rand_data.ndim() == probs.ndim() - 1, \
            "rand_data must have rank probs.ndim() - 1"
        # NOTE(review): comparisons between Tensors build graph ops (see the
        # `>=` below); if `==` does too, this assert cannot fail at build time.
        # Kept from the original shape check; confirm its semantics.
        assert rand_shape == shape(rand_data)

    # Broadcast each random value across the category dimension so it can be
    # compared element-wise with the running CDF.
    rand_data = expand(unsqueeze(rand_data, -1), shape(probs))
    cum_probs = cumsum(probs, dim=-1)
    # 1 where the CDF has reached the random value, 0 before it; argmax then
    # selects the first such position (assuming argmax returns the first
    # maximal index on ties -- confirm).
    reached = cast(cum_probs >= rand_data, probs.dtype)
    return argmax(reached, dim=-1)
| 1466 | |
| 1467 | |
| 1468 | class Conditional: |