This function is used to plot the predictions and display the resulting graph.
(
training_data: np.mat,
predictions: np.ndarray,
col_x: np.ndarray,
col_y: np.ndarray,
cola_name: str,
colb_name: str,
)
| 86 | |
| 87 | |
def plot_preds(
    training_data: np.matrix,
    predictions: np.ndarray,
    col_x: np.ndarray,
    col_y: np.ndarray,
    cola_name: str,
    colb_name: str,
) -> None:
    """
    Plot the observed points and the predicted curve, then display the figure.

    The observations (``col_x`` against ``col_y``) are drawn as a blue scatter
    plot. The predictions are drawn on top as a thick yellow line, ordered by
    ascending value of column 1 of ``training_data`` so the line runs left to
    right instead of zig-zagging in input order.

    Args:
        training_data: 2-D data whose column 1 supplies the x values for the
            prediction line (presumably column 0 is a bias column -- confirm
            against the caller).
        predictions: predicted values, indexed in the same row order as
            ``training_data`` (assumed -- confirm against the caller).
        col_x: x values of the observed points for the scatter plot.
        col_y: y values of the observed points for the scatter plot.
        cola_name: label for the x axis.
        colb_name: label for the y axis.

    Returns:
        None. The figure is shown through ``plt.show()``.
    """
    # Only column 1 is ever plotted, so sort that column alone rather than
    # copying and sorting the whole matrix.
    feature = training_data[:, 1]
    # Row order that sorts the feature; reused to reorder the predictions so
    # each prediction stays paired with its x value.
    order = feature.argsort(0)

    plt.scatter(col_x, col_y, color="blue")
    plt.plot(
        np.sort(feature, axis=0),
        predictions[order],
        color="yellow",
        linewidth=5,
    )
    plt.title("Local Weighted Regression")
    plt.xlabel(cola_name)
    plt.ylabel(colb_name)
    plt.show()
| 112 | |
| 113 | |
| 114 | if __name__ == "__main__": |
no test coverage detected