Test against np.searchsorted, which behaves similarly to ours.
()
| 127 | |
| 128 | |
def test_searchsorted_reference():
    """Test our `searchsorted` against `np.searchsorted`, which behaves similarly.

    Random query points in [eps, 1 - eps] are searched for in a sorted
    reference array whose first and last entries are exactly 0 and 1. The
    second output (`idx_hi`) of our `searchsorted` must equal the insertion
    indices returned by `np.searchsorted` with its default side ('left').
    """
    eps = 1e-7
    n = 30  # Number of query points.
    m = 40  # Number of random interior reference points.

    # Generate query points in [eps, 1-eps].
    v = torch.rand([n]) * (1 - eps - eps) + eps

    # Generate sorted reference points that span all of [0, 1].
    a, _ = torch.sort(torch.rand([m]))
    a = torch.cat([torch.tensor([0.]), a, torch.tensor([1.])])
    _, idx_hi = searchsorted(a, v)

    # Convert to NumPy explicitly instead of relying on NumPy's implicit
    # coercion of torch tensors inside np.searchsorted.
    expected = np.searchsorted(a.numpy(), v.numpy())
    assert_true(np.array_equal(expected, idx_hi.numpy()))
| 143 | |
| 144 | |
| 145 | def test_searchsorted(): |
nothing calls this directly
no test coverage detected