MCPcopy
hub / github.com/ddbourgin/numpy-ml / plot_regression

Function plot_regression

numpy_ml/plots/nonparametric_plots.py:35–101  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

33
34
35def plot_regression():
36 np.random.seed(12345)
37 fig, axes = plt.subplots(4, 4)
38 for i, ax in enumerate(axes.flatten()):
39 n_in = 1
40 n_out = 1
41 d = np.random.randint(1, 5)
42 n_ex = np.random.randint(5, 500)
43 std = np.random.randint(0, 1000)
44 intercept = np.random.rand() * np.random.randint(-300, 300)
45 X_train, y_train, X_test, y_test, coefs = random_regression_problem(
46 n_ex, n_in, n_out, d=d, intercept=intercept, std=std, seed=i
47 )
48
49 LR = LinearRegression(fit_intercept=True)
50 LR.fit(X_train, y_train)
51 y_pred = LR.predict(X_test)
52 loss = np.mean((y_test.flatten() - y_pred.flatten()) ** 2)
53
54 d = 3
55 best_loss = np.inf
56 for gamma in np.linspace(1e-10, 1, 100):
57 for c0 in np.linspace(-1, 1000, 100):
58 kernel = "PolynomialKernel(d={}, gamma={}, c0={})".format(d, gamma, c0)
59 KR_poly = KernelRegression(kernel=kernel)
60 KR_poly.fit(X_train, y_train)
61 y_pred_poly = KR_poly.predict(X_test)
62 loss_poly = np.mean((y_test.flatten() - y_pred_poly.flatten()) ** 2)
63 if loss_poly <= best_loss:
64 KR_poly_best = kernel
65 best_loss = loss_poly
66
67 print("Best kernel: {} || loss: {:.4f}".format(KR_poly_best, best_loss))
68 KR_poly = KernelRegression(kernel=KR_poly_best)
69 KR_poly.fit(X_train, y_train)
70
71 KR_rbf = KernelRegression(kernel="RBFKernel(sigma=1)")
72 KR_rbf.fit(X_train, y_train)
73 y_pred_rbf = KR_rbf.predict(X_test)
74 loss_rbf = np.mean((y_test.flatten() - y_pred_rbf.flatten()) ** 2)
75
76 xmin = min(X_test) - 0.1 * (max(X_test) - min(X_test))
77 xmax = max(X_test) + 0.1 * (max(X_test) - min(X_test))
78 X_plot = np.linspace(xmin, xmax, 100)
79 y_plot = LR.predict(X_plot)
80 y_plot_poly = KR_poly.predict(X_plot)
81 y_plot_rbf = KR_rbf.predict(X_plot)
82
83 ax.scatter(X_test, y_test, alpha=0.5)
84 ax.plot(X_plot, y_plot, label="OLS", alpha=0.5)
85 ax.plot(
86 X_plot, y_plot_poly, label="KR (poly kernel, d={})".format(d), alpha=0.5
87 )
88 ax.plot(X_plot, y_plot_rbf, label="KR (rbf kernel)", alpha=0.5)
89 ax.legend()
90 # ax.set_title(
91 # "MSE\nLR: {:.2f} KR (poly): {:.2f}\nKR (rbf): {:.2f}".format(
92 # loss, loss_poly, loss_rbf

Callers

nothing calls this directly

Calls 7

fit · method · 0.95
predict · method · 0.95
fit · method · 0.95
predict · method · 0.95
KernelRegression · class · 0.90
LinearRegression · class · 0.85

Tested by

no test coverage detected