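Args: target_nodes (Tensor): tensor holding the two target nodes. Returns: subgraph (DGLGraph): the node-induced subgraph sampled around the target nodes, with DRNL labels stored in ndata["z"].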
(self, target_nodes)
| 132 | self.num_workers = num_workers |
| 133 | |
def sample_subgraph(self, target_nodes):
    """
    Extract the subgraph surrounding a pair of target nodes.

    Starting from ``target_nodes``, repeatedly follows out-edges of
    ``self.graph`` for ``self.hop`` hops and takes the node-induced subgraph
    on every node reached (the targets included). Any edges joining the two
    targets, in either direction, are then deleted from that subgraph. Node
    labels computed by ``drnl_node_labeling`` are attached as
    ``subgraph.ndata["z"]``.

    Args:
        target_nodes(Tensor): Tensor of two target nodes; element 0 and
            element 1 are treated as the pair (u, v).
    Returns:
        subgraph(DGLGraph): the induced subgraph with the u-v link(s)
            removed and ``ndata["z"]`` set.
    """
    reached = [target_nodes]
    boundary = target_nodes
    for _ in range(self.hop):
        # Deduplicated destinations of all out-edges leaving the boundary.
        boundary = torch.unique(self.graph.out_edges(boundary)[1])
        reached.append(boundary)

    node_ids = torch.unique(torch.cat(reached))
    subgraph = dgl.node_subgraph(self.graph, node_ids)

    # The subgraph relabels its nodes; ndata[NID] maps each local node back
    # to its id in self.graph, so search it to recover the targets' local ids.
    parent_ids = subgraph.ndata[NID]

    def _local_id(node):
        # node_ids is deduplicated, so exactly one local node matches.
        return int(torch.nonzero(parent_ids == int(node), as_tuple=False))

    u_id = _local_id(target_nodes[0])
    v_id = _local_id(target_nodes[1])

    # Delete every edge joining the two targets, u->v first and then v->u.
    # Per the original note this is aimed at positive subgraphs; when no such
    # edge exists, the has_edges_between checks simply skip the removal.
    for src, dst in ((u_id, v_id), (v_id, u_id)):
        if subgraph.has_edges_between(src, dst):
            edge_ids = subgraph.edge_ids(src, dst, return_uv=True)[2]
            subgraph.remove_edges(edge_ids)

    subgraph.ndata["z"] = drnl_node_labeling(subgraph, u_id, v_id)
    return subgraph
| 177 | |
| 178 | def _collate(self, batch): |
| 179 | batch_graphs, batch_labels = map(list, zip(*batch)) |