Wrapper for fast distance queries.
| 655 | |
| 656 | |
| 657 | class _DistanceQuery: |
| 658 | """Wrapper for fast distance queries.""" |
| 659 | |
| 660 | def __init__(self, xhs, method="BallTree"): |
| 661 | assert method in ("BallTree", "KDTree", "cdist") |
| 662 | |
| 663 | # Fastest for our problems: balltree |
| 664 | if method == "BallTree": |
| 665 | try: |
| 666 | from sklearn.neighbors import BallTree |
| 667 | except ImportError: |
| 668 | logger.info( |
| 669 | "Nearest-neighbor searches will be significantly " |
| 670 | "faster if scikit-learn is installed." |
| 671 | ) |
| 672 | method = "KDTree" |
| 673 | else: |
| 674 | self.query = partial( |
| 675 | _safe_query, |
| 676 | func=BallTree(xhs).query, |
| 677 | reduce=True, |
| 678 | return_distance=True, |
| 679 | ) |
| 680 | |
| 681 | # Then KDTree |
| 682 | if method == "KDTree": |
| 683 | from scipy.spatial import KDTree |
| 684 | |
| 685 | self.query = KDTree(xhs).query |
| 686 | |
| 687 | # Then the worst: cdist |
| 688 | if method == "cdist": |
| 689 | self.query = _CDist(xhs).query |
| 690 | |
| 691 | self.data = xhs |
| 692 | |
| 693 | |
| 694 | @verbose |
no outgoing calls
no test coverage detected