Test nearest neighbor searches.
()
| 98 | |
| 99 | |
def test_compute_nearest():
    """Test nearest neighbor searches.

    Draws 500 random 3-D points, scales each row to unit norm, and selects
    20 of them (by a random permutation of indices) as query points.  Every
    ``_compute_nearest`` backend ("BallTree", "KDTree", "cdist") must map
    each query back to the index it was taken from, both with and without
    ``return_dists=True``.
    """
    methods = ("BallTree", "KDTree", "cdist")

    x = rng.randn(500, 3)
    x /= np.sqrt(np.sum(x**2, axis=1))[:, None]
    nn_true = rng.permutation(np.arange(500, dtype=np.int64))[:20]
    y = x[nn_true]

    # index-only mode: each backend must recover the original indices
    for method in methods:
        assert_array_equal(nn_true, _compute_nearest(x, y, method=method))

    # test distance support
    with_dists = [
        _compute_nearest(x, y, method=method, return_dists=True)
        for method in methods
    ]
    reference = with_dists[0]
    assert_array_equal(reference[0], nn_true)
    # the queries are rows of x itself, so all dists should be 0
    assert_array_equal(reference[1], np.zeros_like(nn_true))

    # Compare every other backend against the first one.  The length is
    # checked for each backend because zip() would otherwise silently drop
    # trailing elements of a longer or shorter result.
    for other in with_dists[1:]:
        assert_equal(len(reference), len(other))
        for expected, actual in zip(reference, other):
            assert_array_equal(expected, actual)
| 124 | |
| 125 | |
| 126 | @testing.requires_testing_data |
nothing calls this directly
no test coverage detected