Repository: github.com/THUDM/CogDL

Function plot_graph

Defined in scripts/display_data.py, lines 14–74
Signature: plot_graph(args)
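
plot_graph reads three attributes from args: dataset (one dataset name or a list of names), depth (the number of BFS hops to expand) and save_dir (the output directory). The minimal sketch below mirrors only the argument handling at the top of the function. It normalizes dataset to a list and derives one display_<name>.png path per dataset. The helper name figure_paths is hypothetical, and argparse.Namespace stands in for whatever parser the script actually uses.

```python
import os.path as osp
from argparse import Namespace


def figure_paths(args):
    """Hypothetical helper: mirror plot_graph's use of args.dataset and args.save_dir."""
    # plot_graph accepts either a single dataset name or a list of names.
    names = args.dataset if isinstance(args.dataset, list) else [args.dataset]
    # One output image per dataset, named display_<name>.png inside save_dir.
    return {name: osp.join(args.save_dir, f"display_{name}.png") for name in names}


args = Namespace(dataset="cora", depth=3, save_dir="figures")
print(figure_paths(args))
```

On a POSIX system this prints {'cora': 'figures/display_cora.png'}. The original rebinds args.dataset in place. This sketch leaves args unchanged so that repeated calls see the caller's original value.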


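# Module-level imports used by plot_graph (assumed to match the top of
# scripts/display_data.py, which this excerpt does not show).
import os.path as osp
import random

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from tabulate import tabulate

from cogdl.datasets import build_dataset_from_name


def plot_graph(args):
    # Accept a single dataset name or a list of names.
    if not isinstance(args.dataset, list):
        args.dataset = [args.dataset]

    for name in args.dataset:
        dataset = build_dataset_from_name(name)
        data = dataset[0]  # first graph in the dataset

        depth = args.depth  # number of BFS hops to expand from the seed node
        pic_file = osp.join(args.save_dir, f"display_{name}.png")

        # Print a one-row summary table for the dataset.
        col_names = [
            "Dataset",
            "#nodes",
            "#edges",
            "#features",
            "#classes",
            "#labeled data",
        ]
        tab_data = [
            [
                name,
                data.x.shape[0],  # rows of the node-feature matrix
                data.edge_index.shape[1],  # columns of edge_index (both directions, if stored that way)
                data.x.shape[1],  # feature dimension
                len(set(data.y.numpy())),  # number of distinct label values
                sum(data.train_mask.numpy()),  # nodes in the training mask
            ]
        ]
        print(tabulate(tab_data, headers=col_names, tablefmt="psql"))

        # Build an undirected networkx graph from the edge_index columns.
        # Reverse-direction duplicates collapse into one edge; nodes without
        # edges are not added to G.
        G = nx.Graph()
        G.add_edges_from([tuple(data.edge_index[:, i].numpy()) for i in range(data.edge_index.shape[1])])

        # Breadth-first search from a random seed node, up to `depth` hops.
        # node_index maps each visited node to its hop distance from the seed;
        # max_index counts the BFS levels reached, including the seed's level 0.
        s = random.choice(list(G.nodes()))
        q = [s]
        node_set = set([s])
        node_index = {s: 0}
        max_index = 1
        for _ in range(depth):
            nq = []
            for x in q:
                for key in G[x].keys():
                    if key not in node_set:
                        nq.append(key)
                        node_set.add(key)
                        node_index[key] = node_index[x] + 1
            if len(nq) > 0:
                max_index += 1
            q = nq

        # One rainbow color per BFS level; nodes nearer the seed are drawn larger.
        cmap = cm.rainbow(np.linspace(0.0, 1.0, max_index))

        for node, index in node_index.items():
            G.nodes[node]["color"] = cmap[index]
            G.nodes[node]["size"] = (max_index - index) * 50

        fig, ax = plt.subplots()
        # (the listing stops here; the remaining lines of the function are not shown)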
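
The BFS loop assigns each node within depth hops of the seed its hop distance (node_index) and counts the levels reached (max_index). Each level then gets one of max_index evenly spaced rainbow colors and a size of (max_index - level) * 50. The minimal, dependency-free sketch below reproduces that labeling and the size rule on a plain adjacency dict, so the logic can be checked without CogDL, networkx or matplotlib. The helper names bfs_levels and node_sizes are hypothetical. The seed is passed in explicitly rather than drawn with random.choice. The node_index dict doubles as the visited set, which is equivalent to the separate node_set in the original.

```python
def bfs_levels(adj, seed, depth):
    """Label nodes within `depth` hops of `seed` with their hop distance.

    Mirrors the loop in plot_graph: returns (node_index, max_index), where
    node_index maps node -> hop distance and max_index counts the BFS levels
    reached, including the seed's level 0.
    """
    queue = [seed]
    node_index = {seed: 0}  # also serves as the visited set
    max_index = 1
    for _ in range(depth):
        next_queue = []
        for x in queue:
            for neighbor in adj[x]:
                if neighbor not in node_index:
                    next_queue.append(neighbor)
                    node_index[neighbor] = node_index[x] + 1
        # A new level is counted only if the frontier actually grew.
        if next_queue:
            max_index += 1
        queue = next_queue
    return node_index, max_index


def node_sizes(node_index, max_index, scale=50):
    # Same rule as plot_graph: the seed is largest, each hop shrinks by `scale`.
    return {node: (max_index - level) * scale for node, level in node_index.items()}


# Hypothetical toy input: the undirected path graph 0-1-2-3-4.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
levels, n_levels = bfs_levels(adj, seed=0, depth=2)
print(levels, n_levels)
print(node_sizes(levels, n_levels))
```

On the toy path graph this yields {0: 0, 1: 1, 2: 2} with 3 levels and sizes {0: 150, 1: 100, 2: 50}. If the frontier empties before depth iterations, max_index stops growing, as it does in plot_graph.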

Callers (1)
display_data.py (file)

Calls (4)
nodes (method)
subgraph (method)
build_dataset_from_name (function)
keys (method)

Tested by
No test coverage detected.