r""" Generate predictions for the targets associated with the rows in `X`. Parameters ---------- X : numpy array of shape `(N', M')` An array of `N'` examples to generate predictions on. Returns ------- y : numpy array of shape `(N', *)` Predicted targets for the `N'` rows in `X`. """
(self, X)
| 60 | self._ball_tree.fit(X, y) |
| 61 | |
def predict(self, X):
    r"""
    Generate predictions for the targets associated with the rows in `X`.

    For every row of `X`, the ``k`` nearest stored examples are retrieved
    from the ball tree and their targets are combined according to the
    ``"classifier"`` and ``"weights"`` hyperparameters:

    - classifier, ``"uniform"``: majority vote among the neighbors.
    - classifier, ``"distance"``: vote weighted by inverse distance.
    - regressor, ``"uniform"``: unweighted mean of the neighbor targets.
    - regressor, ``"distance"``: inverse-distance weighted mean.

    Ties in either classifier vote are resolved in favor of the smallest
    class ID. When one or more neighbors lie at distance 0 from the query,
    those neighbors receive all of the weight (each with weight 1) so that
    the inverse-distance weights stay finite.

    Parameters
    ----------
    X : numpy array of shape `(N', M')`
        An array of `N'` examples to generate predictions on.

    Returns
    -------
    y : numpy array of shape `(N', *)`
        Predicted targets for the `N'` rows in `X`. An entry is ``None``
        when the ``"weights"`` hyperparameter is neither ``"uniform"`` nor
        ``"distance"``.
    """
    H = self.hyperparameters
    predictions = []
    for x in X:
        pred = None
        nearest = self._ball_tree.nearest_neighbors(H["k"], x)
        targets = [n.val for n in nearest]

        weights = None
        if H["weights"] == "distance":
            dists = np.array([n.distance for n in nearest], dtype=float)
            exact = dists == 0
            # A query that coincides with stored points would give infinite
            # inverse-distance weights; let the exact matches decide alone.
            weights = exact.astype(float) if exact.any() else 1.0 / dists

        if H["classifier"]:
            if H["weights"] == "uniform":
                # for consistency with sklearn / scipy.stats.mode, return
                # the smallest class ID in the event of a tie
                counts = Counter(targets).most_common()
                pred, _ = sorted(counts, key=lambda c: (-c[1], c[0]))[0]
            elif H["weights"] == "distance":
                # accumulate the total inverse-distance weight per label
                scores = {}
                for label, w in zip(targets, weights):
                    scores[label] = scores.get(label, 0.0) + w
                if scores:
                    # highest total weight wins; ties go to the smallest
                    # class ID, matching the uniform branch
                    ranked = sorted(
                        scores.items(), key=lambda s: (-s[1], s[0])
                    )
                    pred, _ = ranked[0]
        else:
            if H["weights"] == "uniform":
                pred = np.mean(targets)
            elif H["weights"] == "distance":
                pred = np.average(targets, weights=weights)
        predictions.append(pred)
    return np.array(predictions)