| 5 | import random |
| 6 | |
| 7 | def visualize( |
| 8 | x, |
| 9 | y, |
| 10 | ax=None, |
| 11 | title=None, |
| 12 | draw_legend=True, |
| 13 | draw_centers=False, |
| 14 | draw_cluster_labels=False, |
| 15 | colors=None, |
| 16 | legend_kwargs=None, |
| 17 | label_order=None, |
| 18 | **kwargs |
| 19 | ): |
| 20 | |
| 21 | if ax is None: |
| 22 | _, ax = matplotlib.pyplot.subplots(figsize=(10, 8)) |
| 23 | |
| 24 | if title is not None: |
| 25 | ax.set_title(title) |
| 26 | |
| 27 | plot_params = {"alpha": kwargs.get("alpha", 0.6), "s": kwargs.get("s", 1)} |
| 28 | |
| 29 | # Create main plot |
| 30 | if label_order is not None: |
| 31 | assert all(np.isin(np.unique(y), label_order)) |
| 32 | classes = [l for l in label_order if l in np.unique(y)] |
| 33 | else: |
| 34 | classes = np.unique(y) |
| 35 | if colors is None: |
| 36 | default_colors = matplotlib.rcParams["axes.prop_cycle"] |
| 37 | colors = {k: v["color"] for k, v in zip(classes, default_colors())} |
| 38 | |
| 39 | point_colors = list(map(colors.get, y)) |
| 40 | |
| 41 | ax.scatter(x[:, 0], x[:, 1], c=point_colors, rasterized=True, **plot_params) |
| 42 | |
| 43 | # Plot cluster centers (coordinate-wise median of each class) |
| 44 | if draw_centers: |
| 45 | centers = [] |
| 46 | for yi in classes: |
| 47 | mask = yi == y |
| 48 | centers.append(np.median(x[mask, :2], axis=0)) |
| 49 | centers = np.array(centers) |
| 50 | |
| 51 | center_colors = list(map(colors.get, classes)) |
| 52 | ax.scatter( |
| 53 | centers[:, 0], centers[:, 1], c=center_colors, s=48, alpha=1, edgecolor="k" |
| 54 | ) |
| 55 | |
| 56 | # Draw a text label near each cluster center |
| 57 | if draw_cluster_labels: |
| 58 | for idx, label in enumerate(classes): |
| 59 | ax.text( |
| 60 | centers[idx, 0], |
| 61 | centers[idx, 1] + 2.2, |
| 62 | label, |
| 63 | fontsize=kwargs.get("fontsize", 6), |
| 64 | horizontalalignment="center", |