MCPcopy
hub / github.com/dmlc/dgl / sample_subgraph

Method sample_subgraph

examples/pytorch/seal/sampler.py:134–176  ·  view source on GitHub ↗

Args: target_nodes (Tensor): tensor of the two target nodes. Returns: subgraph (DGLGraph): the extracted subgraph around those nodes.

(self, target_nodes)

Source from the content-addressed store, hash-verified

132 self.num_workers = num_workers
133
def sample_subgraph(self, target_nodes):
    """Extract the subgraph surrounding a pair of target nodes.

    Starting from the two target nodes, follows outgoing edges of
    ``self.graph`` for ``self.hop`` rounds. It then takes the node-induced
    subgraph on every node reached (including the targets) and deletes any
    edges that directly join the two targets, in either direction. Finally it
    stores the result of ``drnl_node_labeling`` in ``ndata["z"]``.

    Args:
        target_nodes(Tensor): Tensor of two target nodes, given as node IDs
            of ``self.graph``.
    Returns:
        subgraph(DGLGraph): the induced subgraph, with target-to-target
            edges removed and per-node labels stored under ``ndata["z"]``.
    """
    # Gather the targets plus every node reachable within `self.hop`
    # outgoing steps. Each round expands from the previous round's
    # de-duplicated destinations.
    reached = [target_nodes]
    frontier = target_nodes
    for _ in range(self.hop):
        neighbors = self.graph.out_edges(frontier)[1]
        frontier = torch.unique(neighbors)
        reached.append(frontier)

    all_nodes = torch.unique(torch.cat(reached))
    subgraph = dgl.node_subgraph(self.graph, all_nodes)

    # node_subgraph relabels nodes. Recover each target's local ID by
    # matching its original ID against the parent IDs kept in ndata[NID].
    parent_ids = subgraph.ndata[NID]

    def _local_id(node):
        return int(torch.nonzero(parent_ids == int(node), as_tuple=False))

    u_id = _local_id(target_nodes[0])
    v_id = _local_id(target_nodes[1])

    # Drop any link between the two targets (present in positive
    # samples). Check u->v first, then v->u. Edge IDs are looked up
    # again after each removal because removing edges renumbers them.
    for src, dst in ((u_id, v_id), (v_id, u_id)):
        if subgraph.has_edges_between(src, dst):
            link_ids = subgraph.edge_ids(src, dst, return_uv=True)[2]
            subgraph.remove_edges(link_ids)

    subgraph.ndata["z"] = drnl_node_labeling(subgraph, u_id, v_id)
    return subgraph
177
178 def _collate(self, batch):
179 batch_graphs, batch_labels = map(list, zip(*batch))

Callers 1

Calls 8

drnl_node_labeling · Function · 0.90
append · Method · 0.80
nonzero · Method · 0.80
remove_edges · Method · 0.80
out_edges · Method · 0.45
node_subgraph · Method · 0.45
has_edges_between · Method · 0.45
edge_ids · Method · 0.45

Tested by 1