MCPcopy
hub / github.com/ddbourgin/numpy-ml / test_ball_tree

Function test_ball_tree

numpy_ml/tests/test_utils.py:181–217  ·  view source on GitHub ↗
(N=1)

Source from the content-addressed store, hash-verified

179
180
def test_ball_tree(N=1):
    """Compare numpy-ml's ``BallTree`` k-NN query against scikit-learn's.

    Each trial draws a random dataset and fits both trees on it. It then
    queries the ``k`` nearest neighbors of a random point and checks that
    the returned neighbors and their distances agree.

    Parameters
    ----------
    N : int
        Number of random trials to run. Default is 1.
    """
    np.random.seed(12345)
    # Keep the dataset size in its own variable. Reusing ``N`` here would
    # overwrite the requested trial count and make the number of iterations
    # depend on the last random draw.
    for _ in range(N):
        n_samples = np.random.randint(2, 100)
        n_features = np.random.randint(2, 100)
        k = np.random.randint(1, n_samples)
        # cap leaf size at n_samples - 1 (presumably so the tree must split)
        leaf_size = min(np.random.randint(1, 10), n_samples - 1)

        X = np.random.rand(n_samples, n_features)
        tree = BallTree(leaf_size=leaf_size, metric=euclidean)
        tree.fit(X)

        x = np.random.rand(n_features)
        mine = tree.nearest_neighbors(k, x)
        assert len(mine) == k

        # Order our neighbors by distance so they line up with sklearn's output.
        mine_dist = np.array([n.distance for n in mine])
        mine_neighb = np.array([n.key for n in mine])
        order = np.argsort(mine_dist)
        mine_dist = mine_dist[order]
        mine_neighb = mine_neighb[order]

        sk = sk_BallTree(X, leaf_size=leaf_size)
        theirs_dist, ind = sk.query(x.reshape(1, -1), k=k)
        theirs_dist = theirs_dist.ravel()
        ind = ind.ravel()
        order = np.argsort(theirs_dist)
        theirs_dist = theirs_dist[order]
        theirs_neighb = X[ind[order]]

        # Compare all k neighbors (rows of X) and distances at once.
        np.testing.assert_almost_equal(mine_neighb, theirs_neighb)
        np.testing.assert_almost_equal(mine_dist, theirs_dist)

        print("PASSED")
218
219
220#######################################################################

Callers

nothing calls this directly

Calls 3

fit (method) · 0.95
nearest_neighbors (method) · 0.95
BallTree (class) · 0.90

Tested by

no test coverage detected