hub / github.com/dmlc/dgl / test_glob_att_pool

Function test_glob_att_pool

tests/python/tensorflow/test_nn.py:201–216
()


def test_glob_att_pool():
    g = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())

    gap = nn.GlobalAttentionPooling(layers.Dense(1), layers.Dense(10))
    print(gap)

    # test#1: basic
    h0 = F.randn((g.num_nodes(), 5))
    h1 = gap(g, h0)
    assert h1.shape[0] == 1 and h1.shape[1] == 10 and h1.ndim == 2

    # test#2: batched graph
    bg = dgl.batch([g, g, g, g])
    h0 = F.randn((bg.num_nodes(), 5))
    h1 = gap(bg, h0)
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2
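
The shape assertions above can be illustrated with a small NumPy sketch. It assumes attention pooling applies a per-graph softmax to the gate scores and then takes a weighted sum of the transformed node features. The helper `attention_pool` and the bias-free weight matrices standing in for `layers.Dense(1)` and `layers.Dense(10)` are hypothetical and are not part of DGL.

```python
import numpy as np


def attention_pool(feats, graph_sizes, w_gate, w_feat):
    """Hypothetical stand-in for attention pooling over a batch of graphs.

    feats:       (total_nodes, in_dim) node features of all graphs, concatenated
    graph_sizes: number of nodes in each graph of the batch
    w_gate:      (in_dim, 1) weights playing the role of the gate network
    w_feat:      (in_dim, out_dim) weights playing the role of the feature network
    """
    outputs = []
    start = 0
    for n in graph_sizes:
        x = feats[start:start + n]           # nodes belonging to this graph
        scores = x @ w_gate                  # (n, 1) unnormalized gate scores
        scores = scores - scores.max()       # shift for numerical stability
        weights = np.exp(scores)
        weights = weights / weights.sum()    # softmax over this graph's nodes
        outputs.append((weights * (x @ w_feat)).sum(axis=0))
        start += n
    return np.stack(outputs)                 # (num_graphs, out_dim)


rng = np.random.default_rng(0)
w_gate = rng.standard_normal((5, 1))
w_feat = rng.standard_normal((5, 10))

# One graph with 10 nodes, then a batch of four such graphs.
single = attention_pool(rng.standard_normal((10, 5)), [10], w_gate, w_feat)
batched = attention_pool(rng.standard_normal((40, 5)), [10] * 4, w_gate, w_feat)
print(single.shape, batched.shape)
```

Under these assumptions the readout has one row per graph and one column per output feature, which matches what the two assertions in the test check.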

Callers 1

test_nn.py · File · 0.70

Calls 3

to · Method · 0.45
ctx · Method · 0.45
num_nodes · Method · 0.45

Tested by

no test coverage detected