MCPcopy
hub / github.com/ddbourgin/numpy-ml / KNN

Class KNN

numpy_ml/nonparametric/knn.py:9–101  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

7
8
9class KNN:
10 def __init__(
11 self, k=5, leaf_size=40, classifier=True, metric=None, weights="uniform",
12 ):
13 """
14 A `k`-nearest neighbors (kNN) model relying on a ball tree for efficient
15 computation.
16
17 Parameters
18 ----------
19 k : int
20 The number of neighbors to use during prediction. Default is 5.
21 leaf_size : int
22 The maximum number of datapoints at each leaf in the ball tree.
23 Default is 40.
24 classifier : bool
25 Whether to treat the values in Y as class labels (classifier =
26 True) or real-valued targets (classifier = False). Default is True.
27 metric : :doc:`Distance metric <numpy_ml.utils.distance_metrics>` or None
28 The distance metric to use for computing nearest neighbors. If
29 None, use the :func:`~numpy_ml.utils.distance_metrics.euclidean`
30 metric by default. Default is None.
31 weights : {'uniform', 'distance'}
32 How to weight the predictions from each neighbors. 'uniform'
33 assigns uniform weights to each neighbor, while 'distance' assigns
34 weights proportional to the inverse of the distance from the query
35 point. Default is 'uniform'.
36 """
37 self._ball_tree = BallTree(leaf_size=leaf_size, metric=metric)
38 self.hyperparameters = {
39 "id": "KNN",
40 "k": k,
41 "leaf_size": leaf_size,
42 "classifier": classifier,
43 "metric": str(metric),
44 "weights": weights,
45 }
46
47 def fit(self, X, y):
48 r"""
49 Fit the model to the data and targets in `X` and `y`
50
51 Parameters
52 ----------
53 X : numpy array of shape `(N, M)`
54 An array of `N` examples to generate predictions on.
55 y : numpy array of shape `(N, *)`
56 Targets for the `N` rows in `X`.
57 """
58 if X.ndim != 2:
59 raise Exception("X must be two-dimensional")
60 self._ball_tree.fit(X, y)
61
62 def predict(self, X):
63 r"""
64 Generate predictions for the targets associated with the rows in `X`.
65
66 Parameters

Callers 3

plot_knn (function) · 0.90
test_knn_regression (function) · 0.90
test_knn_clf (function) · 0.90

Calls

no outgoing calls

Tested by 2

test_knn_regression (function) · 0.72
test_knn_clf (function) · 0.72