Find nearest neighbors. Parameters: xhs : array, shape=(n_samples, n_dim) — points of the data set. rr : array, shape=(n_query, n_dim) — points to find nearest neighbors for. method : str — the query method; depending on which optional libraries are installed (scikit-learn, and the installed scipy version), the search may fall back to a slow brute-force method.
(xhs, rr, method="BallTree", return_dists=False)
| 616 | |
| 617 | |
| 618 | def _compute_nearest(xhs, rr, method="BallTree", return_dists=False): |
| 619 | """Find nearest neighbors. |
| 620 | |
| 621 | Parameters |
| 622 | ---------- |
| 623 | xhs : array, shape=(n_samples, n_dim) |
| 624 | Points of data set. |
| 625 | rr : array, shape=(n_query, n_dim) |
| 626 | Points to find nearest neighbors for. |
| 627 | method : str |
| 628 | The query method. If scikit-learn and scipy<1.0 are installed, |
| 629 | it will fall back to the slow brute-force search. |
| 630 | return_dists : bool |
| 631 | If True, return associated distances. |
| 632 | |
| 633 | Returns |
| 634 | ------- |
| 635 | nearest : array, shape=(n_query,) |
| 636 | Index of nearest neighbor in xhs for every point in rr. |
| 637 | distances : array, shape=(n_query,) |
| 638 | The distances. Only returned if return_dists is True. |
| 639 | """ |
| 640 | if xhs.size == 0 or rr.size == 0: |
| 641 | if return_dists: |
| 642 | return np.array([], int), np.array([]) |
| 643 | return np.array([], int) |
| 644 | tree = _DistanceQuery(xhs, method=method) |
| 645 | out = tree.query(rr) |
| 646 | return out[::-1] if return_dists else out[1] |
| 647 | |
| 648 | |
| 649 | def _safe_query(rr, func, reduce=False, **kwargs): |