Repository: github.com/THUDM/CogDL

Function batch_sum_pooling

docs/source/examples/1graph.py:47–57
(x, batch)
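As a formula: suppose x has shape [N, F] with one feature row per node, and batch has length N and gives each node's graph index. The indices are assumed to be non-negative integers. The function then returns a [B, F] tensor whose row g is the sum of the rows of x assigned to graph g. Any index in 0..B-1 that no node uses gets a row of zeros.

```latex
\mathrm{out}_{g,f} \;=\; \sum_{i \,:\, \mathrm{batch}_i = g} x_{i,f},
\qquad g = 0, \dots, B-1,
\qquad B = 1 + \max_i \mathrm{batch}_i
```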


import torch

# The following code snippet shows how to do global pooling to sum over features of nodes in each graph:
# --------------------------------------------------------------------------------------------------------
def batch_sum_pooling(x, batch):
    # x: [num_nodes, num_features]; batch: int64 tensor giving each node's graph index
    batch_size = int(torch.max(batch.cpu())) + 1
    # Match x's dtype so scatter_add_ also works for float64/float16 inputs
    res = torch.zeros(batch_size, x.size(1), dtype=x.dtype).to(x.device)
    # Add every node's feature row into the row of the graph it belongs to
    out = res.scatter_add_(
        dim=0,
        index=batch.unsqueeze(-1).expand_as(x),
        src=x,
    )
    return out


# %%
# How to edit the graph?
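The following is a minimal usage sketch on made-up toy data: five nodes with two features each, split across two graphs. The block repeats the function, including the dtype-matching zero buffer, so it runs on its own.

```python
import torch


def batch_sum_pooling(x, batch):
    # Same logic as above: one zero row per graph, matching x's dtype and device
    batch_size = int(torch.max(batch.cpu())) + 1
    res = torch.zeros(batch_size, x.size(1), dtype=x.dtype, device=x.device)
    return res.scatter_add_(dim=0, index=batch.unsqueeze(-1).expand_as(x), src=x)


# Nodes 0-2 belong to graph 0, nodes 3-4 belong to graph 1
x = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0],
                  [7.0, 8.0],
                  [9.0, 10.0]])
batch = torch.tensor([0, 0, 0, 1, 1])  # int64, as scatter_add_ requires for its index

pooled = batch_sum_pooling(x, batch)
# Row 0 = 1+3+5, 2+4+6 and row 1 = 7+9, 8+10, i.e. [[9., 12.], [16., 18.]]
print(pooled)
```

The index is expanded to x's shape because scatter_add_ expects index and src to line up element by element. That expansion is why each node contributes its whole feature row to the matching output row.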

Callers

nothing calls this directly

Calls 1

toMethod · 0.45

Tested by

no test coverage detected
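No tests are reported for this function, so here is a sketch of a check one might add. The test name and the naive_sum_pooling reference helper are hypothetical. The test compares the function against a per-graph loop and against Tensor.index_add_ on random float64 data, with one graph index deliberately left unused. The float64 input also exercises the dtype handling: scatter_add_ expects self and src to share a dtype, so a default float32 zero buffer would not accept it.

```python
import torch


def batch_sum_pooling(x, batch):
    batch_size = int(torch.max(batch.cpu())) + 1
    res = torch.zeros(batch_size, x.size(1), dtype=x.dtype, device=x.device)
    return res.scatter_add_(dim=0, index=batch.unsqueeze(-1).expand_as(x), src=x)


def naive_sum_pooling(x, batch):
    # Hypothetical reference: sum the rows of each graph one graph at a time
    batch_size = int(batch.max()) + 1
    rows = [x[batch == g].sum(dim=0) for g in range(batch_size)]
    return torch.stack(rows)


def test_batch_sum_pooling_matches_reference():  # hypothetical test name
    torch.manual_seed(0)
    x = torch.randn(20, 4, dtype=torch.float64)
    # Graph ids 0..5 with id 3 unused, so row 3 of the output should be all zeros
    batch = torch.tensor([0, 1, 2, 4, 5] * 4)
    out = batch_sum_pooling(x, batch)
    assert out.shape == (6, 4)
    assert torch.allclose(out, naive_sum_pooling(x, batch))
    assert torch.equal(out[3], torch.zeros(4, dtype=torch.float64))
    # index_add_ takes a 1-D index directly, so no expand_as is needed
    alt = torch.zeros_like(out).index_add_(0, batch, x)
    assert torch.allclose(out, alt)


test_batch_sum_pooling_matches_reference()
```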