(self, graph: dgl.DGLGraph, feature: torch.Tensor)
| 33 | self.non_linearity = non_linearity |
| 34 | |
def forward(self, graph: dgl.DGLGraph, feature: torch.Tensor):
    """Keep the top-scoring nodes of each graph and gate their features.

    A score is computed for every node with ``self.score_layer``. ``topk``
    selects which nodes to keep, presumably the top ``self.ratio`` share per
    graph in the batch (confirm against ``topk``). The kept nodes' features
    are scaled by ``self.non_linearity`` of their scores, and the graph is
    reduced to the node-induced subgraph on those nodes.

    Args:
        graph: Input graph, possibly a batch of graphs.
        feature: Node features, indexed row-wise by node id. This is
            presumably shaped (num_nodes, dim); TODO confirm with callers.

    Returns:
        A tuple ``(graph, feature, perm)``:
        - ``graph``: the node-induced subgraph on the kept nodes, with
          ``batch_num_nodes`` set from ``topk``'s output;
        - ``feature``: the kept nodes' features, each row multiplied by
          ``self.non_linearity`` of that node's score;
        - ``perm``: the indices of the kept nodes in the input graph.
    """
    # Query once: the same per-graph node counts feed both topk arguments.
    batch_num_nodes = graph.batch_num_nodes()

    # Flatten to exactly one score per node. A bare ``squeeze()`` would turn
    # an input with a single node (score shape (1, 1)) into a 0-d tensor,
    # which cannot be indexed per node. ``reshape(-1)`` keeps it 1-D.
    # This assumes score_layer emits one score per node, e.g. (num_nodes, 1);
    # TODO confirm.
    score = self.score_layer(graph, feature).reshape(-1)
    perm, next_batch_num_nodes = topk(
        score,
        self.ratio,
        get_batch_id(batch_num_nodes),
        batch_num_nodes,
    )

    # Gate each kept node's feature row by its activated score. ``view(-1, 1)``
    # broadcasts the per-node gate across the feature columns.
    feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1)
    graph = dgl.node_subgraph(graph, perm)

    # node_subgraph currently does not support batched graphs: the
    # 'batch_num_nodes' of the resulting subgraph is None, so we set it
    # manually here. Global pooling does not use 'batch_num_edges', so it
    # can be left as None or unchanged.
    graph.set_batch_num_nodes(next_batch_num_nodes)

    return graph, feature, perm
| 54 | |
| 55 | |
| 56 | class ConvPoolBlock(torch.nn.Module): |
nothing calls this directly
no test coverage detected