(args)
| 12 | |
| 13 | |
| 14 | def plot_graph(args): |
| 15 | if not isinstance(args.dataset, list): |
| 16 | args.dataset = [args.dataset] |
| 17 | |
| 18 | for name in args.dataset: |
| 19 | dataset = build_dataset_from_name(name) |
| 20 | data = dataset[0] |
| 21 | |
| 22 | depth = args.depth |
| 23 | pic_file = osp.join(args.save_dir, f"display_{name}.png") |
| 24 | |
| 25 | col_names = [ |
| 26 | "Dataset", |
| 27 | "#nodes", |
| 28 | "#edges", |
| 29 | "#features", |
| 30 | "#classes", |
| 31 | "#labeled data", |
| 32 | ] |
| 33 | tab_data = [ |
| 34 | [ |
| 35 | name, |
| 36 | data.x.shape[0], |
| 37 | data.edge_index.shape[1], |
| 38 | data.x.shape[1], |
| 39 | len(set(data.y.numpy())), |
| 40 | sum(data.train_mask.numpy()), |
| 41 | ] |
| 42 | ] |
| 43 | print(tabulate(tab_data, headers=col_names, tablefmt="psql")) |
| 44 | |
| 45 | G = nx.Graph() |
| 46 | G.add_edges_from([tuple(data.edge_index[:, i].numpy()) for i in range(data.edge_index.shape[1])]) |
| 47 | |
| 48 | s = random.choice(list(G.nodes())) |
| 49 | q = [s] |
| 50 | node_set = set([s]) |
| 51 | node_index = {s: 0} |
| 52 | max_index = 1 |
| 53 | for _ in range(depth): |
| 54 | nq = [] |
| 55 | for x in q: |
| 56 | for key in G[x].keys(): |
| 57 | if key not in node_set: |
| 58 | nq.append(key) |
| 59 | node_set.add(key) |
| 60 | node_index[key] = node_index[x] + 1 |
| 61 | if len(nq) > 0: |
| 62 | max_index += 1 |
| 63 | q = nq |
| 64 | |
| 65 | cmap = cm.rainbow(np.linspace(0.0, 1.0, max_index)) |
| 66 | |
| 67 | for node, index in node_index.items(): |
| 68 | G.nodes[node]["color"] = cmap[index] |
| 69 | G.nodes[node]["size"] = (max_index - index) * 50 |
| 70 | |
| 71 | fig, ax = plt.subplots() |
no test coverage detected