MCPcopy
hub / github.com/ddbourgin/numpy-ml / plot_regression

Function plot_regression

numpy_ml/plots/lm_plots.py:219–281  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

217
218
def _fit_and_score(model, X_train, y_train, X_test, y_test):
    """Fit ``model`` on the training split and return its test-set MSE.

    Parameters
    ----------
    model : object
        Any estimator exposing ``fit(X, y)`` and ``predict(X)``.
    X_train, y_train, X_test, y_test : :py:class:`ndarray <numpy.ndarray>`
        Train/test splits as returned by ``random_regression_problem``.

    Returns
    -------
    mse : float
        Mean of ``(y_test - model.predict(X_test)) ** 2``.
    """
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    return np.mean((y_test - y_pred) ** 2)


def plot_regression():
    """Plot linear vs. Bayesian linear regression fits on random 1-D problems.

    Creates a 4x4 grid of subplots. For each subplot:

    * a random single-feature regression problem is generated with a random
      noise level (``std``) and a random intercept;
    * three models are fit to its training split: ``LinearRegression``,
      ``BayesianLinearRegressionKnownVariance`` and
      ``BayesianLinearRegressionUnknownVariance`` (``alpha=1, beta=2``),
      both Bayesian models receiving ``mu = [intercept, *coefs]``;
    * the test points and the three fitted lines are drawn over the test
      range padded by 10% on each side, and each model's test-set MSE is
      shown in the subplot title.

    The global NumPy RNG is seeded with 12345, so the figure is reproducible.
    """
    np.random.seed(12345)

    # Problem size is identical for every panel; only noise and intercept vary.
    n_ex, n_in, n_out = 50, 1, 1

    fig, axes = plt.subplots(4, 4)
    for i, ax in enumerate(axes.flatten()):
        # Keep the draw order (noise level, then intercept) so the seeded
        # sequence of random problems stays the same.
        std = np.random.randint(0, 100)
        intercept = np.random.rand() * np.random.randint(-300, 300)
        X_train, y_train, X_test, y_test, coefs = random_regression_problem(
            n_ex, n_in, n_out, intercept=intercept, std=std, seed=i
        )

        # [intercept, coef_1, ..., coef_n_in]; presumably the parameters used
        # to generate the data -- verify against random_regression_problem.
        mu = np.c_[intercept, coefs][0]

        lr = LinearRegression(fit_intercept=True)
        loss = _fit_and_score(lr, X_train, y_train, X_test, y_test)

        # NOTE(review): `std` is passed to the data generator as a standard
        # deviation, so taking its square root for `sigma` looks suspicious.
        # Confirm what `sigma` expects before changing it.
        lr_var = BayesianLinearRegressionKnownVariance(
            mu=mu, sigma=np.sqrt(std), V=None, fit_intercept=True,
        )
        loss_var = _fit_and_score(lr_var, X_train, y_train, X_test, y_test)

        lr_novar = BayesianLinearRegressionUnknownVariance(
            alpha=1, beta=2, mu=mu, V=None, fit_intercept=True,
        )
        loss_novar = _fit_and_score(lr_novar, X_train, y_train, X_test, y_test)

        # Evaluate the fitted lines over the test range, padded 10% per side.
        # min/max along axis 0 keep the (n_in,) shape, so X_plot is (100, n_in).
        x_lo, x_hi = X_test.min(axis=0), X_test.max(axis=0)
        pad = 0.1 * (x_hi - x_lo)
        X_plot = np.linspace(x_lo - pad, x_hi + pad, 100)

        ax.scatter(X_test, y_test, marker="x", alpha=0.5)
        ax.plot(X_plot, lr.predict(X_plot), label="linear regression", alpha=0.5)
        ax.plot(X_plot, lr_var.predict(X_plot), label="Bayes (w var)", alpha=0.5)
        ax.plot(X_plot, lr_novar.predict(X_plot), label="Bayes (no var)", alpha=0.5)
        ax.legend()
        ax.set_title(
            "MSE\nLR: {:.2f} Bayes (w var): {:.2f}\nBayes (no var): {:.2f}".format(
                loss, loss_var, loss_novar
            )
        )

        ax.xaxis.set_ticklabels([])

    # NOTE(review): upstream lm_plots.py continues this function through
    # line 281; those final lines are not included in this listing.

Callers

nothing calls this directly

Calls 10

fit (method) · 0.95
predict (method) · 0.95
fit (method) · 0.95
predict (method) · 0.95
fit (method) · 0.95
predict (method) · 0.95
LinearRegression (class) · 0.90

Tested by

no test coverage detected