
Method multinomial

tinygrad/tensor.py:798–823 (github.com/tinygrad/tinygrad)

Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.
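The weights in `self` do not have to sum to 1: the with-replacement branch divides the cumulative sum by its last entry, so index `i` is drawn with probability `w[i] / sum(w)`. The sketch below is a stdlib-only illustration, not tinygrad code: `normalized_probs` is a hypothetical helper, and Python's `random.choices` stands in as an analogue of sampling with replacement. Its draws will not match tinygrad's, which uses its own RNG.

```python
import random

def normalized_probs(weights):
    """Hypothetical helper: probability of each index under unnormalized weights."""
    total = sum(weights)
    return [w / total for w in weights]  # p_i = w_i / sum(w)

weights = [1, 2, 3, 4]
print(normalized_probs(weights))  # [0.1, 0.2, 0.3, 0.4]

# Stdlib analogue (not tinygrad) of t.multinomial(20, replacement=True):
# 20 indices drawn with replacement, index i chosen with probability p_i.
random.seed(42)
samples = random.choices(range(len(weights)), weights=weights, k=20)
print(samples)
```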

(self:Tensor, num_samples:int = 1, replacement:bool = False)
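Judging from the assertions and the final `squeeze` in the source below, the method accepts either a 1-D weight vector of length `K` or a 2-D batch of shape `(B, K)`, and returns `int32` indices of shape `(num_samples,)` or `(B, num_samples)` respectively. Without replacement, `num_samples` may not exceed `K`. The helper `multinomial_output_shape` is hypothetical and only restates that contract in plain Python; it does not call tinygrad, and it raises `ValueError` where tinygrad uses `assert`.

```python
def multinomial_output_shape(shape, num_samples=1, replacement=False):
    """Hypothetical helper restating the shape checks in Tensor.multinomial."""
    ndim = len(shape)
    # tinygrad asserts these conditions; this sketch raises ValueError instead
    if not (1 <= ndim <= 2 and num_samples > 0):
        raise ValueError(f"{ndim=} must be 1 or 2 dim, {num_samples=} must be positive")
    population = shape[-1]  # number of categories K
    if not replacement and num_samples > population:
        raise ValueError("without replacement, num_samples must not exceed the population size")
    # A 1-D input is treated as a batch of one and squeezed back at the end
    return (num_samples,) if ndim == 1 else (shape[0], num_samples)

print(multinomial_output_shape((4,), 20, replacement=True))  # (20,)
print(multinomial_output_shape((8, 4), 3))                   # (8, 3)
```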

Source:

````python
def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
  """
  Returns a tensor with `num_samples` indices sampled from a multinomial distribution weighted by `self`.

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor([1, 2, 3, 4])
  print(t.multinomial(20, replacement=True).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor([1, 2, 3, 4])
  print(t.multinomial(3, replacement=False).numpy())
  ```
  """
  assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
  weight = self.unsqueeze(0) if self.ndim == 1 else self
  assert replacement or num_samples <= weight.shape[1], "no replacement samples must not exceed population size"
  if replacement or num_samples == 1:
    cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
    unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1).to(self.device)
    indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
  else:
    # Efraimidis–Spirakis
    indices = (weight.rand_like(dtype=dtypes.float32).log2() / weight).topk(num_samples, dim=1)[1]
  return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
````
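The `replacement or num_samples == 1` branch is inverse-transform sampling: it normalizes the cumulative sum into a CDF, draws one uniform `u` in [0, 1) per sample, and takes as the index the number of CDF entries that are `<= u` (the `>= cdf` comparison summed over axis 2). Below is a minimal pure-Python sketch of that counting trick for a single weight row. The name `inverse_cdf_sample` is hypothetical, and `random.Random.random` stands in for `Tensor.rand`.

```python
import random
from itertools import accumulate

def inverse_cdf_sample(weights, num_samples, rng=random):
    """Hypothetical sketch of the with-replacement branch for one weight row."""
    cumulative = list(accumulate(weights))          # like weight.cumsum(1)
    cdf = [c / cumulative[-1] for c in cumulative]  # normalize by the last entry
    indices = []
    for _ in range(num_samples):
        u = rng.random()  # uniform in [0, 1), like Tensor.rand
        # Count CDF entries <= u; this mirrors (unif >= cdf).sum(2)
        indices.append(sum(u >= c for c in cdf))
    return indices

print(inverse_cdf_sample([1, 2, 3, 4], 10, random.Random(0)))
```

Because `u >= 0` always holds, a category with zero weight contributes a repeated CDF value and is never returned. Because `u < 1` and the last CDF entry is exactly 1, the count never exceeds `K - 1`.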
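The `else` branch is Efraimidis–Spirakis weighted sampling without replacement. Each item gets the key `u ** (1 / w)`, with `u` uniform, and the `num_samples` largest keys are kept. tinygrad ranks `log2(u) / w` instead, which is the base-2 logarithm of that key. Since the logarithm is increasing, `topk` selects the same items, and the log form presumably also sidesteps powers that underflow for small weights. The pure-Python sketch below uses the hypothetical name `es_sample_without_replacement`, and `heapq.nlargest` plays the role of `topk`.

```python
import heapq
import math
import random

def es_sample_without_replacement(weights, num_samples, rng=random):
    """Hypothetical sketch of Efraimidis–Spirakis sampling without replacement."""
    if num_samples > len(weights):
        raise ValueError("num_samples must not exceed the population size")
    keys = []
    for i, w in enumerate(weights):
        u = rng.random()  # uniform in [0, 1)
        # log2(u) / w is the log of the classic key u ** (1 / w): same ordering.
        # u == 0 or w == 0 yields -inf here, so such items rank last (as in tinygrad).
        key = math.log2(u) / w if (u > 0 and w > 0) else -math.inf
        keys.append((key, i))
    # Keep the indices of the num_samples largest keys, like .topk(num_samples)[1]
    return [i for _, i in heapq.nlargest(num_samples, keys)]

print(es_sample_without_replacement([1, 2, 3, 4], 3, random.Random(42)))
```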

Callers 11

sample (function) · 0.80
test_multinomial (method) · 0.80
_check_with_torch (method) · 0.80
sample_one (method) · 0.80
sample_three (method) · 0.80
forward (method) · 0.80
generate (function) · 0.80
get_action (function) · 0.80
generate (method) · 0.80

Calls 13

unsqueeze (method) · 0.80
float (method) · 0.80
cumsum (method) · 0.80
rand (method) · 0.80
permute (method) · 0.80
sum (method) · 0.80
expand (method) · 0.80
topk (method) · 0.80
log2 (method) · 0.80
rand_like (method) · 0.80
squeeze (method) · 0.80
to (method) · 0.45

Tested by 6

test_multinomial (method) · 0.64
_check_with_torch (method) · 0.64
sample_one (method) · 0.64
sample_three (method) · 0.64