MCPcopy
hub / github.com/THUDM/CogDL / _call

Method _call

cogdl/pipelines.py:79–120  ·  view source on GitHub ↗
(self, dataset="cora", seed=-1, depth=3, **kwargs)

Source from the content-addressed store, hash-verified

77 super(DatasetVisualPipeline, self).__init__(app, **kwargs)
78
def _call(self, dataset="cora", seed=-1, depth=3, **kwargs):
    """Sample an ego network around ``seed`` and save a picture of it.

    The dataset's first graph is loaded into an undirected ``networkx``
    graph, a breadth-first search of at most ``depth`` hops is run from
    ``seed``, and the induced subgraph of the visited nodes is drawn to
    ``"<dataset name>.png"`` (a relative path), colored and sized by hop
    distance from the seed.

    Args:
        dataset: Dataset name, or a list whose first element is the name.
        seed: Node to start from; ``-1`` picks a random node of the graph.
        depth: Maximum number of BFS hops to expand.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The last BFS frontier as a list of nodes: the nodes first reached
        on the final hop, ``[seed]`` when ``depth`` is 0, or an empty list
        when the search ran out of new nodes before ``depth`` hops.
    """
    if isinstance(dataset, list):
        dataset = dataset[0]
    name = dataset
    dataset = build_dataset_from_name(name)
    data = dataset[0]

    G = nx.Graph()
    edge_index = torch.stack(data.edge_index)
    # Convert the whole tensor once instead of slicing/converting per edge;
    # each row of the transposed array is one (source, target) pair.
    G.add_edges_from(map(tuple, edge_index.numpy().T))

    if seed == -1:
        seed = random.choice(list(G.nodes()))

    # node_index maps every visited node to its hop distance from the seed;
    # its keys double as the visited set.
    frontier = [seed]
    node_index = {seed: 0}
    max_index = 1  # number of BFS levels discovered so far (seed level included)
    for _ in range(depth):
        next_frontier = []
        for node in frontier:
            for neighbor in G[node]:
                if neighbor not in node_index:
                    next_frontier.append(neighbor)
                    node_index[neighbor] = node_index[node] + 1
        frontier = next_frontier
        if not frontier:
            # Nothing new was reached; further hops cannot add nodes.
            break
        max_index += 1

    # One color per BFS level.
    cmap = cm.rainbow(np.linspace(0.0, 1.0, max_index))

    for node, index in node_index.items():
        G.nodes[node]["color"] = cmap[index]
        # Nodes closer to the seed are drawn larger.
        G.nodes[node]["size"] = (max_index - index) * 50

    pic_file = f"{name}.png"
    fig, _ = plt.subplots()
    try:
        plot_network(G.subgraph(list(node_index)), node_style=use_attributes())
        plt.savefig(pic_file)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"Sampled ego network saved to {pic_file}")

    return frontier
121
122
123class OAGBertInferencePipepline(Pipeline):

Callers

nothing calls this directly

Calls 4

nodesMethod · 0.95
subgraphMethod · 0.95
build_dataset_from_nameFunction · 0.90
keysMethod · 0.45

Tested by

no test coverage detected