MCPcopy
hub / github.com/tinygrad/tinygrad / test_multinomial

Method test_multinomial

test/backend/test_randomness.py:350–391  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

348 self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape))
349
def test_multinomial(self):
  """Check Tensor.multinomial against torch.multinomial.

  Covers argument validation, with-replacement sampling (1D, single-row 2D and multi-row 2D weights),
  and without-replacement sampling both eagerly and under TinyJit, where every replay must still
  produce fresh, valid draws.
  """
  # invalid inputs: a scalar tensor and a request for zero samples both raise AssertionError
  self.assertRaises(AssertionError, lambda: Tensor(2).multinomial(1, replacement=False))
  self.assertRaises(AssertionError, lambda: Tensor([1, 9]).multinomial(0, replacement=False))

  def _check_with_torch(w, num_samples, replacement):
    # compare tinygrad and torch draws row by row; a 1D result is treated as a single row
    tiny_res = Tensor(w).multinomial(num_samples, replacement=replacement)
    torch_res = torch.tensor(w).multinomial(num_samples, replacement=replacement)
    self.assertEqual(tiny_res.shape, torch_res.shape)
    if torch_res.ndim == 1:
      tiny_res, torch_res = tiny_res.unsqueeze(0), torch_res.unsqueeze(0)
    for i in range(torch_res.shape[0]):
      # bind the current rows as defaults so the lambdas never depend on the loop variable
      self.assertTrue(equal_distribution(lambda *_, t=tiny_res[i]: t, lambda _, t=torch_res[i]: t))

  _check_with_torch(w=[0.231, 0., 1., 0.5], num_samples=300, replacement=True)
  _check_with_torch(w=[[0.2, 0.8]], num_samples=300, replacement=True)  # 2D but only 1 row
  _check_with_torch(w=[[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]], num_samples=300, replacement=True)

  # no-replacement: asking for more samples than there are categories is rejected
  two_w = [0.1, 0.9]
  self.assertRaises(AssertionError, lambda: Tensor(two_w).multinomial(100, replacement=False))

  # a single draw under TinyJit must keep following torch's distribution across replays
  @TinyJit
  def sample_one(): return Tensor(two_w).multinomial(1, replacement=False).realize()

  tiny_samples = [sample_one().item() for _ in range(400)]
  torch_samples = [torch.tensor(two_w).multinomial(1, replacement=False).item() for _ in range(400)]
  self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_samples), lambda _: torch.tensor(torch_samples)))

  # eager no-replacement draws: indices are distinct, each call uses fresh randomness,
  # and drawing every category returns each index exactly once (including index 0, whose weight is 0)
  ramp_w = list(range(32))
  s1 = Tensor(ramp_w).multinomial(5, replacement=False).numpy()
  self.assertEqual(len(set(s1.tolist())), 5)
  s2 = Tensor(ramp_w).multinomial(5, replacement=False).numpy()
  # two independent ordered 5-draws out of 31 non-zero categories coinciding is extremely unlikely
  self.assertFalse(np.array_equal(s1, s2))
  full = Tensor(ramp_w).multinomial(len(ramp_w), replacement=False).numpy()
  # compare against the category indices, not the weights (they only coincide because the weights are 0..31)
  self.assertEqual(sorted(full.tolist()), list(range(len(ramp_w))))

  # multi-sample draws under TinyJit: every replay yields distinct, in-range indices,
  # and each output position matches torch's marginal distribution
  four_w = [0.1, 0.2, 0.3, 0.4]
  @TinyJit
  def sample_three(): return Tensor(four_w).multinomial(3, replacement=False).realize()

  tiny_draws = np.array([sample_three().numpy() for _ in range(400)])
  torch_draws = np.array([torch.tensor(four_w).multinomial(3, replacement=False).numpy() for _ in range(400)])
  for row in tiny_draws:
    self.assertEqual(len(set(row.tolist())), 3)
  self.assertTrue(((tiny_draws >= 0) & (tiny_draws < len(four_w))).all())
  for pos in range(3):
    self.assertTrue(equal_distribution(lambda *_, p=pos: Tensor(tiny_draws[:, p]), lambda _, p=pos: torch.tensor(torch_draws[:, p])))
392
393 @unittest.skip("this test is flaky")
394 def test_multinomial_counterexample(self):

Callers

nothing calls this directly

Calls 7

Tensor (class) · 0.90
equal_distribution (function) · 0.85
multinomial (method) · 0.80
item (method) · 0.80
tensor (method) · 0.80
numpy (method) · 0.45
tolist (method) · 0.45

Tested by

no test coverage detected