(self)
| 348 | self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape)) |
| 349 | |
def test_multinomial(self):
  """Check Tensor.multinomial argument validation and compare its sampling distribution against torch.

  Covers sampling with replacement (1-D input, 2-D single-row input, 2-D multi-row input) and
  sampling without replacement (single draw, distinct draws, a full permutation, and the
  per-position distribution of multi-draw samples).
  """
  # Argument validation: a 0-d tensor and a request for zero samples are both expected to assert.
  self.assertRaises(AssertionError, lambda: Tensor(2).multinomial(1, replacement=False))
  self.assertRaises(AssertionError, lambda: Tensor([1, 9]).multinomial(0, replacement=False))

  def _compare_to_torch(weights, n_draws, replacement):
    # Sample once from tinygrad and once from torch with identical arguments, then check
    # that shapes match and that each row's samples pass equal_distribution.
    ours = Tensor(weights).multinomial(n_draws, replacement=replacement)
    ref = torch.tensor(weights).multinomial(n_draws, replacement=replacement)
    self.assertEqual(ours.shape, ref.shape)
    # Lift 1-D results to a single-row 2-D layout so every case is compared row by row.
    if len(ref.shape) == 1:
      ours, ref = ours.unsqueeze(0), ref.unsqueeze(0)
    for row in range(ref.shape[0]):
      # The lambdas are consumed immediately by equal_distribution, so capturing `row` is safe.
      self.assertTrue(equal_distribution(lambda *_: ours[row], lambda _: ref[row]))

  # With replacement: a 1-D weight vector, a 2-D input with a single row, and a 2-D input with two rows.
  for weights in ([0.231, 0., 1., 0.5], [[0.2, 0.8]], [[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]]):
    _compare_to_torch(weights, 300, True)

  # Without replacement: drawing more samples than there are categories is expected to assert.
  two_way = [0.1, 0.9]
  self.assertRaises(AssertionError, lambda: Tensor(two_way).multinomial(100, replacement=False))

  # Each name below is bound once, so the jitted closures cannot observe a later rebinding.
  @TinyJit
  def draw_single(): return Tensor(two_way).multinomial(1, replacement=False).realize()

  ours_single = [draw_single().item() for _ in range(400)]
  ref_single = [torch.tensor(two_way).multinomial(1, replacement=False).item() for _ in range(400)]
  self.assertTrue(equal_distribution(lambda *_: Tensor(ours_single), lambda _: torch.tensor(ref_single)))

  # Over 32 categories weighted 0..31, a 5-draw sample has no repeats and two separate draws differ.
  ramp = list(range(32))
  first = Tensor(ramp).multinomial(5, replacement=False).numpy()
  self.assertEqual(len(set(first.tolist())), 5)
  second = Tensor(ramp).multinomial(5, replacement=False).numpy()
  self.assertFalse(np.array_equal(first, second))
  # Drawing every category must yield each index exactly once.
  everything = Tensor(ramp).multinomial(len(ramp), replacement=False).numpy()
  self.assertEqual(sorted(everything.tolist()), list(range(len(ramp))))

  # The distribution at each of the 3 output positions of a 3-draw sample should match torch.
  probs = [0.1, 0.2, 0.3, 0.4]

  @TinyJit
  def draw_triple(): return Tensor(probs).multinomial(3, replacement=False).realize()

  ours_triples = np.array([draw_triple().numpy() for _ in range(400)])
  ref_triples = np.array([torch.tensor(probs).multinomial(3, replacement=False).numpy() for _ in range(400)])
  for col in range(3):
    self.assertTrue(equal_distribution(lambda *_: Tensor(ours_triples[:, col]), lambda _: torch.tensor(ref_triples[:, col])))
| 392 | |
| 393 | @unittest.skip("this test is flaky") |
| 394 | def test_multinomial_counterexample(self): |
nothing calls this directly
no test coverage detected