```python
class BallTree:
    def __init__(self, leaf_size=40, metric=None):
        """
        A ball tree data structure.

        Notes
        -----
        A ball tree is a binary tree in which every node defines a
        `D`-dimensional hypersphere ("ball") containing a subset of the points
        to be searched. Each internal node of the tree partitions the data
        points into two disjoint sets, each associated with a different
        ball. While the balls themselves may intersect, each point is assigned
        to one or the other ball in the partition according to its distance
        from the ball's center. Each leaf node in the tree defines a ball and
        enumerates all data points inside that ball.

        Parameters
        ----------
        leaf_size : int
            The maximum number of data points at each leaf. Default is 40.
        metric : :doc:`Distance metric <numpy_ml.utils.distance_metrics>` or None
            The distance metric to use for computing nearest neighbors. If
            None, use the :func:`~numpy_ml.utils.distance_metrics.euclidean`
            metric. Default is None.

        References
        ----------
        .. [1] Omohundro, S. M. (1989). "Five balltree construction algorithms".
           *ICSI Technical Report TR-89-063*.
        .. [2] Liu, T., Moore, A., & Gray, A. (2006). "New algorithms for efficient
           high-dimensional nonparametric classification". *J. Mach. Learn. Res.,
           7*, 1135-1158.
        """
        self.root = None
        self.leaf_size = leaf_size
        self.metric = metric if metric is not None else euclidean

    def fit(self, X, y=None):
        r"""
        Build a ball tree recursively using the `k`-d construction algorithm,
        which runs in O(M N log N) time when the median at each node is found
        in linear time.

        Notes
        -----
        Recursively divides the data into nodes defined by a centroid `C` and
        radius `r` such that each point below the node lies within the
        hypersphere defined by `C` and `r`.

        Parameters
        ----------
        X : :py:class:`ndarray <numpy.ndarray>` of shape `(N, M)`
            An array of `N` examples, each with `M` features.
        y : :py:class:`ndarray <numpy.ndarray>` of shape `(N, \*)` or None
            An array of target values / labels associated with the entries in
            `X`. Default is None.
        """
        centroid, left_X, left_y, right_X, right_y = self._split(X, y)
        self.root = BallTreeNode(centroid=centroid)
```
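The recursive construction described in the `fit` docstring can be sketched as follows. This is a minimal, self-contained sketch, not the numpy-ml implementation (its `_split` and `BallTreeNode` are not shown here and may differ). It assumes the Euclidean metric, takes each node's centroid to be the mean of its points, and follows the usual description of the `k`-d construction rule: split at the median of the feature with the largest spread. The names `Node`, `build`, and `iter_leaves` are hypothetical.

```python
from dataclasses import dataclass
from typing import Iterator, Optional

import numpy as np


@dataclass
class Node:
    """Hypothetical stand-in for BallTreeNode: a ball given by center and radius."""
    center: np.ndarray
    radius: float
    points: Optional[np.ndarray] = None  # only set on leaf nodes
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def build(X: np.ndarray, leaf_size: int = 40) -> Node:
    """Recursively build a ball tree over the rows of X (shape (N, M))."""
    if leaf_size < 1:
        raise ValueError("leaf_size must be at least 1")

    # The node's ball: the centroid of its points and the distance to the
    # farthest point, so every point below this node lies inside the ball.
    center = X.mean(axis=0)
    radius = float(np.max(np.linalg.norm(X - center, axis=1)))

    if X.shape[0] <= leaf_size:
        return Node(center, radius, points=X)

    # k-d style split: choose the feature with the largest spread...
    split_dim = int(np.argmax(X.max(axis=0) - X.min(axis=0)))
    # ...and partition the points at the median along that feature.
    order = np.argsort(X[:, split_dim], kind="stable")
    mid = X.shape[0] // 2
    left = build(X[order[:mid]], leaf_size)
    right = build(X[order[mid:]], leaf_size)
    return Node(center, radius, left=left, right=right)


def iter_leaves(node: Node) -> Iterator[Node]:
    """Yield every leaf under `node`, depth first."""
    if node.points is not None:
        yield node
    else:
        yield from iter_leaves(node.left)
        yield from iter_leaves(node.right)
```

Because each split is by index position rather than by value, both halves are non-empty whenever a node holds more than `leaf_size >= 1` points, so the recursion always terminates, even with duplicate points.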
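Storing a center and radius per node, as the class docstring describes, is what makes nearest-neighbor search fast. For any true metric, the triangle inequality gives d(q, x) >= d(q, C) - r for every point x in the ball (C, r), so an entire ball can be skipped once this lower bound reaches the best distance found so far. The sketch below shows the bound over a flat list of balls rather than a full tree, assuming the Euclidean metric. `ball_lower_bound` and `nearest_in_balls` are hypothetical helper names, not part of numpy-ml.

```python
import numpy as np


def ball_lower_bound(q: np.ndarray, center: np.ndarray, radius: float) -> float:
    """Smallest possible Euclidean distance from q to any point in the ball."""
    # Triangle inequality: ||q - x|| >= ||q - c|| - ||x - c|| >= ||q - c|| - r.
    return max(0.0, float(np.linalg.norm(q - center)) - radius)


def nearest_in_balls(q, balls):
    """Nearest neighbor of q among non-empty balls given as (center, radius, points).

    Returns (point, distance, n_pruned).
    """
    # Visit balls in order of their lower bound so pruning kicks in early.
    order = sorted(balls, key=lambda b: ball_lower_bound(q, b[0], b[1]))
    best_pt, best_d = None, np.inf
    for k, (center, radius, pts) in enumerate(order):
        if ball_lower_bound(q, center, radius) >= best_d:
            # Bounds are sorted, so no remaining ball can hold a closer point.
            return best_pt, best_d, len(order) - k
        dists = np.linalg.norm(pts - q, axis=1)
        i = int(np.argmin(dists))
        if dists[i] < best_d:
            best_pt, best_d = pts[i], float(dists[i])
    return best_pt, best_d, 0
```

For example, with two well-separated clusters each enclosed in its own ball, a query near one cluster never computes distances to the points of the other. A ball tree applies the same test recursively, pruning whole subtrees.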