Repository: github.com/dmlc/dgl

Method forward

examples/pytorch/sagpool/layer.py:35–53
(self, graph: dgl.DGLGraph, feature: torch.Tensor)


```python
        self.non_linearity = non_linearity

    def forward(self, graph: dgl.DGLGraph, feature: torch.Tensor):
        score = self.score_layer(graph, feature).squeeze()
        perm, next_batch_num_nodes = topk(
            score,
            self.ratio,
            get_batch_id(graph.batch_num_nodes()),
            graph.batch_num_nodes(),
        )
        feature = feature[perm] * self.non_linearity(score[perm]).view(-1, 1)
        graph = dgl.node_subgraph(graph, perm)

        # node_subgraph currently does not support batch-graph,
        # the 'batch_num_nodes' of the result subgraph is None.
        # So we manually set the 'batch_num_nodes' here.
        # Since global pooling has nothing to do with 'batch_num_edges',
        # we can leave it to be None or unchanged.
        graph.set_batch_num_nodes(next_batch_num_nodes)

        return graph, feature, perm


class ConvPoolBlock(torch.nn.Module):
```
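For readers without DGL at hand, the selection and gating steps of `forward` can be sketched in plain Python. This is not the example's own `topk`: `per_graph_topk` and `gate_features` are hypothetical names, the rule of keeping `ceil(ratio * n)` nodes per graph follows the original SAGPool formulation and is only an assumption about what `topk` computes, and `math.tanh` stands in for whatever `self.non_linearity` is configured to be. The kept indices come out grouped graph by graph, which is what lets the returned per-graph counts line up with the pooled batch.

```python
import math
from typing import List, Tuple


# Hypothetical helper, not part of DGL: per-graph top-k over a batch of graphs.
def per_graph_topk(
    scores: List[float], ratio: float, batch_num_nodes: List[int]
) -> Tuple[List[int], List[int]]:
    """Keep the ceil(ratio * n) highest-scoring nodes of every graph.

    Returns (perm, next_batch_num_nodes): global indices of the kept nodes,
    grouped graph by graph, and the number of nodes kept per graph.
    """
    perm: List[int] = []
    next_batch_num_nodes: List[int] = []
    offset = 0  # global index of the first node of the current graph
    for n in batch_num_nodes:
        k = math.ceil(ratio * n)  # assumed selection rule
        # This graph's node indices, highest score first.
        ranked = sorted(range(offset, offset + n), key=lambda i: scores[i], reverse=True)
        perm.extend(ranked[:k])
        next_batch_num_nodes.append(k)
        offset += n
    return perm, next_batch_num_nodes


# Hypothetical helper: mirrors feature[perm] * non_linearity(score[perm]).view(-1, 1),
# with tanh assumed as the non-linearity.
def gate_features(
    features: List[List[float]], scores: List[float], perm: List[int]
) -> List[List[float]]:
    """Keep the selected rows and scale each one by tanh of its node score."""
    return [[x * math.tanh(scores[i]) for x in features[i]] for i in perm]


# Two graphs batched together: nodes 0-2 form graph 0, nodes 3-4 form graph 1.
scores = [0.1, 0.9, 0.5, 0.3, 0.8]
features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
perm, counts = per_graph_topk(scores, ratio=0.5, batch_num_nodes=[3, 2])
print(perm, counts)  # [1, 2, 4] [2, 1]
gated = gate_features(features, scores, perm)
```

Gating by the score, rather than only selecting by it, keeps the scoring layer on the gradient path; a hard top-k selection alone would pass no gradient back to the scores.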

Callers

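no direct callers found (as a torch.nn.Module method, forward is normally invoked through the module's __call__ rather than by name)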

Calls (5)

topk · Function · 0.90
get_batch_id · Function · 0.90
batch_num_nodes · Method · 0.80
set_batch_num_nodes · Method · 0.80
node_subgraph · Method · 0.45
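The batch bookkeeping around `node_subgraph` can be illustrated in plain Python as well. The helpers below are hypothetical and work on a flat edge list: `batch_ids` expands per-graph node counts into one graph index per node (the role `get_batch_id` appears to play), `induced_subgraph` keeps only edges between selected nodes and relabels node `perm[i]` as `i` (assuming, as the row-aligned `feature[perm]` implies, that this is the node order `dgl.node_subgraph` produces), and `count_per_graph` recomputes how many kept nodes belong to each graph. Like the flat edge list here, the subgraph carries no per-graph boundaries, which is why `forward` restores them with `set_batch_num_nodes`.

```python
from typing import Dict, List, Tuple


# Hypothetical helpers, not part of DGL.
def batch_ids(batch_num_nodes: List[int]) -> List[int]:
    """Graph index of every node, e.g. [3, 2] -> [0, 0, 0, 1, 1]."""
    return [g for g, n in enumerate(batch_num_nodes) for _ in range(n)]


def induced_subgraph(
    edges: List[Tuple[int, int]], perm: List[int]
) -> List[Tuple[int, int]]:
    """Keep edges whose endpoints are both selected; relabel node perm[i] as i."""
    new_id: Dict[int, int] = {old: new for new, old in enumerate(perm)}
    return [(new_id[u], new_id[v]) for u, v in edges if u in new_id and v in new_id]


def count_per_graph(perm: List[int], batch_num_nodes: List[int]) -> List[int]:
    """Number of selected nodes that fall in each graph of the batch."""
    ids = batch_ids(batch_num_nodes)
    counts = [0] * len(batch_num_nodes)
    for i in perm:
        counts[ids[i]] += 1
    return counts


# Two graphs batched together: nodes 0-2 form graph 0, nodes 3-4 form graph 1.
batch_num_nodes = [3, 2]
edges = [(0, 1), (1, 2), (2, 0), (3, 4)]
perm = [1, 2, 4]

print(batch_ids(batch_num_nodes))              # [0, 0, 0, 1, 1]
print(induced_subgraph(edges, perm))           # [(0, 1)]
print(count_per_graph(perm, batch_num_nodes))  # [2, 1]
```

In `forward` itself the counts are not recomputed: `topk` already returns them as `next_batch_num_nodes`. The sketch only shows that they are fully determined by `perm` and the per-node graph indices.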

Tested by

no test coverage detected