(embeddings, label_path, show=True)
| 34 | |
| 35 | |
def plot_embeddings(embeddings, label_path, show=True, *, random_state=None):
    """Project node embeddings to 2-D with t-SNE and scatter-plot them by label.

    Args:
        embeddings: Mapping indexed by node id (``embeddings[node]``) for every
            node id read from ``label_path``.
        label_path: Path to the node-label file; converted with ``str()``
            before being passed to ``read_node_label``.
        show: If true, display the figure with ``plt.show()``; otherwise
            close the current figure with ``plt.close()``.
        random_state: Optional seed forwarded to ``TSNE`` so the 2-D layout
            is reproducible. ``None`` (the default) keeps t-SNE's own behavior.
    """
    x_data, y_data = read_node_label(str(label_path))

    embedding_list = np.array([embeddings[node] for node in x_data])

    # scikit-learn's TSNE uses perplexity=30.0 by default and requires
    # perplexity < n_samples; clamp it so small graphs do not raise ValueError.
    # For 31+ nodes this is exactly the default, so results are unchanged.
    perplexity = min(30.0, max(len(embedding_list) - 1, 1))
    node_pos = TSNE(
        n_components=2, perplexity=perplexity, random_state=random_state
    ).fit_transform(embedding_list)

    # Group row positions by each node's first label. Assumes every entry of
    # y_data is a sequence of labels and only label[0] is used for colouring
    # — TODO confirm against read_node_label.
    color_idx = {}
    for index, label in enumerate(y_data):
        color_idx.setdefault(label[0], []).append(index)

    # One scatter call per label so each group gets its own colour and legend entry.
    for label, indexes in color_idx.items():
        plt.scatter(node_pos[indexes, 0], node_pos[indexes, 1], label=label)
    plt.legend()
    if show:
        plt.show()
    else:
        plt.close()
| 54 | |
| 55 | |
| 56 | def main(smoke=False, show=True): |
no test coverage detected