Plot contour for 2D data.
(self, data=None, ax=None, holdon=False)
| 136 | return self.weights * likelihood |
| 137 | |
def plot(self, data=None, ax=None, holdon=False):
    """Plot contour for 2D data.

    Scatters ``data`` coloured by mixture-component assignment and overlays
    one density contour per component, using ``self.means[k]`` and
    ``self.covs[k]`` for ``k`` in ``range(self.K)``.

    Args:
        data: Points to scatter, shaped ``(n, 2)`` (any array-like). Defaults
            to ``self.X`` together with the stored ``self.assignments``;
            otherwise the assignments come from ``self.predict(data)``.
        ax: Matplotlib axes to draw on. A new figure and axes are created
            when ``None``.
        holdon: When false, ``plt.show()`` is called after drawing.

    Raises:
        AttributeError: If ``self.X`` is not 2-D with exactly two columns.
        ValueError: If ``data`` is given but is not shaped ``(n, 2)``.
    """
    if not (len(self.X.shape) == 2 and self.X.shape[1] == 2):
        raise AttributeError("Only support for visualizing 2D data.")

    # Validate inputs before creating a figure, so a bad call does not leave
    # an empty window behind.
    if data is None:
        data = self.X
        assignments = self.assignments
    else:
        data = np.asarray(data)
        if data.ndim != 2 or data.shape[1] != 2:
            raise ValueError(f"data must have shape (n, 2), got {data.shape}")
        assignments = self.predict(data)

    if ax is None:
        _, ax = plt.subplots()

    palette = "bgrcmyk"

    def color_of(assignment):
        # Cycle through the palette so any number of components gets a colour.
        return palette[int(assignment) % len(palette)]

    # The grid spans both the fitted data and the plotted data, so points
    # passed via ``data`` never fall outside the contour region.
    margin = 0.2
    extent = np.vstack([self.X, data])
    xmin, ymin = extent.min(axis=0) - margin
    xmax, ymax = extent.max(axis=0) + margin

    # Base step of 0.025, coarsened for wide data ranges: a fixed step alone
    # allocates ((range) / 0.025) ** 2 grid points, which exhausts memory once
    # the data spans more than a few dozen units.
    max_points_per_axis = 500
    delta = max(
        0.025,
        (xmax - xmin) / max_points_per_axis,
        (ymax - ymin) / max_points_per_axis,
    )
    grid_x, grid_y = np.meshgrid(
        np.arange(xmin, xmax, delta), np.arange(ymin, ymax, delta)
    )
    # (rows, 2) array of grid coordinates, built once and shared by every
    # component's density evaluation.
    grid_points = np.column_stack([grid_x.ravel(), grid_y.ravel()])

    # Scatter points, coloured by assignment when assignments are available.
    c = None if assignments is None else [color_of(a) for a in assignments]
    ax.scatter(data[:, 0], data[:, 1], c=c)

    # One contour per mixture component, in that component's colour.
    for k in range(self.K):
        density = multivariate_normal.pdf(
            grid_points, self.means[k], self.covs[k]
        ).reshape(grid_x.shape)
        ax.contour(grid_x, grid_y, density, colors=color_of(k))

    if not holdon:
        plt.show()
no test coverage detected