(N=1)
| 179 | |
| 180 | |
def test_ball_tree(N=1):
    """Compare ``BallTree.nearest_neighbors`` against scikit-learn's BallTree.

    Each trial draws a random dataset, a random query point, a random ``k``
    and a random leaf size. It then asserts that the ``k`` nearest neighbors
    (points and distances) found by the local ``BallTree`` match those
    returned by ``sk_BallTree`` once both are ordered by distance.

    Parameters
    ----------
    N : int
        Number of random trials to run. Defaults to 1.
    """
    np.random.seed(12345)
    for _ in range(N):
        # Use dedicated names for the dataset shape. Reassigning ``N`` here
        # would clobber the requested number of trials.
        n_ex = np.random.randint(2, 100)
        n_feats = np.random.randint(2, 100)
        k = np.random.randint(1, n_ex)
        # Cap the leaf size at n_ex - 1.
        ls = min(np.random.randint(1, 10), n_ex - 1)

        X = np.random.rand(n_ex, n_feats)
        BT = BallTree(leaf_size=ls, metric=euclidean)
        BT.fit(X)

        x = np.random.rand(n_feats)
        mine = BT.nearest_neighbors(k, x)
        assert len(mine) == k

        # Order our results by distance so they line up with sklearn's.
        mine_neighb = np.array([n.key for n in mine])
        mine_dist = np.array([n.distance for n in mine])
        sort_ix = np.argsort(mine_dist)
        mine_dist = mine_dist[sort_ix]
        mine_neighb = mine_neighb[sort_ix]

        sk = sk_BallTree(X, leaf_size=ls)
        theirs_dist, ind = sk.query(x.reshape(1, -1), k=k)
        # query() returns arrays shaped (1, k) for a single query row.
        theirs_dist = theirs_dist.flatten()
        sort_ix = np.argsort(theirs_dist)
        theirs_dist = theirs_dist[sort_ix]
        theirs_neighb = X[ind.flatten()[sort_ix]]

        np.testing.assert_almost_equal(mine_neighb, theirs_neighb)
        np.testing.assert_almost_equal(mine_dist, theirs_dist)

        print("PASSED")
| 218 | |
| 219 | |
| 220 | ####################################################################### |
nothing calls this directly
no test coverage detected