(embeddings, label_path, show=True)
| 28 | |
| 29 | |
def plot_embeddings(embeddings, label_path, show=True, *, random_state=None):
    """Project node embeddings to 2-D with t-SNE and scatter-plot them by label.

    Args:
        embeddings: Indexable by every node id returned from the label file;
            each value is that node's embedding vector (presumably all of equal
            length so they stack into a 2-D array -- TODO confirm with caller).
        label_path: Path (str or path-like) to the node-label file; it is
            converted with ``str()`` and parsed by
            ``read_node_label(..., skip_head=True)``.
        show: If True, display the figure with ``plt.show()``; otherwise close
            the current figure without displaying it.
        random_state: Optional seed forwarded to ``TSNE`` so the 2-D layout is
            reproducible. ``None`` (the default) matches the previous,
            unseeded behaviour.

    Raises:
        ValueError: If the label file yields no nodes. Errors raised by
            ``TSNE`` itself (e.g. too few samples) propagate unchanged.
    """
    x_data, y_data = read_node_label(str(label_path), skip_head=True)
    # Fail early with a clear message instead of an opaque sklearn error about
    # a 1-D array when there is nothing to embed.
    if len(x_data) == 0:
        raise ValueError(f"no labelled nodes found in {label_path}")

    embedding_list = np.array([embeddings[node] for node in x_data])
    node_pos = TSNE(n_components=2, random_state=random_state).fit_transform(
        embedding_list
    )

    # Group row indices by each node's first label so every class becomes its
    # own scatter series (and legend entry); dict order keeps first-seen order.
    color_idx = {}
    for index, label in enumerate(y_data):
        color_idx.setdefault(label[0], []).append(index)

    for label, indexes in color_idx.items():
        plt.scatter(node_pos[indexes, 0], node_pos[indexes, 1], label=label)
    plt.legend()
    if show:
        plt.show()
    else:
        # Close the current figure so headless/batch calls don't leave it open.
        plt.close()
| 48 | |
| 49 | |
| 50 | def main(smoke=False, show=True): |
no test coverage detected