MCPcopy
hub / github.com/ddbourgin/numpy-ml / plot_knn

Function plot_knn

numpy_ml/plots/nonparametric_plots.py:104–163  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

102
103
def plot_knn():
    """
    Plot OLS and k-nearest-neighbor regression fits on random 1-D problems.

    Draws a 4x4 grid of panels. For each panel a fresh random regression
    problem is generated with ``random_regression_problem``. On that problem
    the function fits an ordinary-least-squares ``LinearRegression`` model and
    uniform-weight ``KNN`` regressors with k = 1, 5 and 10. It then overlays
    each model's predictions on a scatter of the held-out test points. The
    test-set mean squared error of every model is shown in the panel title.

    Returns
    -------
    fig : :py:class:`matplotlib.figure.Figure`
        The figure holding the 4x4 grid of panels.
    """
    np.random.seed(12345)
    n_in, n_out = 1, 1  # 1-D inputs/targets so each fit can be drawn as a curve
    ks = (1, 5, 10)

    fig, axes = plt.subplots(4, 4)
    for i, ax in enumerate(axes.flatten()):
        # These draws consume the global RNG stream seeded above; keep their
        # order unchanged or every panel's random problem changes.
        d = np.random.randint(1, 5)
        n_ex = np.random.randint(5, 500)
        std = np.random.randint(0, 1000)
        intercept = np.random.rand() * np.random.randint(-300, 300)
        X_train, y_train, X_test, y_test, _ = random_regression_problem(
            n_ex, n_in, n_out, d=d, intercept=intercept, std=std, seed=i
        )

        # (legend label, short label for the title, unfitted model), in the
        # same order the original plotted them: OLS, then KNN for each k.
        models = [("OLS", "OLS", LinearRegression(fit_intercept=True))]
        for k in ks:
            knn = KNN(k=k, classifier=False, leaf_size=10, weights="uniform")
            models.append(("KNN (k={})".format(k), "k={}".format(k), knn))

        # Evaluation grid spanning the test inputs plus 10% padding per side,
        # shaped as a column so it matches the (n, 1) inputs fed to predict.
        lo, hi = X_test.min(), X_test.max()
        pad = 0.1 * (hi - lo)
        X_plot = np.linspace(lo - pad, hi + pad, 100).reshape(-1, 1)

        ax.scatter(X_test, y_test, alpha=0.5)
        scores = []
        for label, short, model in models:
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            mse = np.mean((y_test.flatten() - y_pred.flatten()) ** 2)
            scores.append("{}: {:.4g}".format(short, mse))
            ax.plot(X_plot, model.predict(X_plot), label=label, alpha=0.5)
        ax.legend()

        # Report each model's test MSE, two per line so it fits small panels.
        # ".4g" keeps the numbers short across the wide range of noise levels.
        rows = ["  ".join(scores[j : j + 2]) for j in range(0, len(scores), 2)]
        ax.set_title("MSE\n" + "\n".join(rows), fontsize=6)

        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])

    plt.tight_layout()
    return fig

Callers

nothing calls this directly

Calls 7

fit (method) · 0.95
predict (method) · 0.95
fit (method) · 0.95
predict (method) · 0.95
KNN (class) · 0.90
LinearRegression (class) · 0.85

Tested by

no test coverage detected