Test sampling when given a large flat distribution.
(randomized, single_jitter)
| 425 | |
| 426 | |
def impl_test_sample_large_flat(randomized, single_jitter):
    """Test sampling when given a large flat distribution.

    Builds 100000 bin edges (0 .. 99999) with one equal weight per interval
    between consecutive edges, draws 100 samples through
    ``importance_sampling``, and checks that the samples stay inside the bin
    range and pass two Kolmogorov-Smirnov uniformity checks.

    Args:
        randomized: Forwarded to ``importance_sampling`` as ``perturb``.
        single_jitter: Forwarded to ``importance_sampling`` as
            ``single_jitter``.
    """
    num_samples = 100
    num_bins = 100000
    bins = torch.arange(num_bins)
    # One weight per interval between consecutive bin edges, all equal.
    weights = np.ones(len(bins) - 1)
    # NOTE(review): `weights` is a NumPy array, so the `.log()` call below only
    # works if `torch_maximum` returns a tensor -- confirm against that helper.
    # `[None]` adds a leading axis of size 1 to both inputs; `[0]` removes the
    # matching leading axis from the result.
    samples = importance_sampling(
        bins[None],
        torch_softmax(torch_maximum(1e-15, weights[None]).log()),
        num_samples,
        perturb=randomized,
        single_jitter=single_jitter,
    )[0]

    lo, hi = bins[0], bins[-1]

    # All samples should be within the range of the bins.
    assert_true(torch.all(samples >= lo))
    assert_true(torch.all(samples <= hi))

    # The fractional part of each sample (its offset inside a unit-width bin)
    # should resemble a uniform distribution on [0, 1].
    samples_mod = torch.fmod(samples, 1)
    assert_true(
        sp.stats.kstest(samples_mod, 'uniform', (0, 1)).statistic <= 0.2)

    # All samples should collectively resemble a uniform distribution.
    # scipy's 'uniform' is parameterized as (loc, scale) and spans
    # [loc, loc + scale], so the second argument must be the width of the bin
    # range rather than its upper edge. The two coincide only when lo == 0.
    assert_true(
        sp.stats.kstest(samples, 'uniform', (lo, hi - lo)).statistic <= 0.2)
| 452 | |
| 453 | |
# Variant of the large-flat sampling test with both options switched off.
test_sample_large_flat_deterministic = partial(
    impl_test_sample_large_flat,
    False,  # randomized
    False,  # single_jitter
)
nothing calls this directly
no test coverage detected