hub / github.com/dmlc/dgl / test_gat_conv

Function test_gat_conv

tests/python/mxnet/test_nn.py:191–206
(g, idtype, out_dim, num_heads)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("out_dim", [1, 20])
@pytest.mark.parametrize("num_heads", [1, 5])
def test_gat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(10, out_dim, num_heads)  # num_heads is parametrized: 1 or 5
    gat.initialize(ctx=ctx)
    print(gat)
    feat = F.randn((g.number_of_src_nodes(), 10))
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    # Passing True as the third argument also returns the attention weights
    _, a = gat(g, feat, True)
    assert a.shape == (g.num_edges(), num_heads, 1)

    # test residual connection
    gat = nn.GATConv(10, out_dim, num_heads, residual=True)
    gat.initialize(ctx=ctx)
    h = gat(g, feat)
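
The two `@pytest.mark.parametrize` decorators in the listing stack, so pytest runs the test over every combination of `out_dim` and `num_heads`. The following is a minimal sketch of that grid and of the shapes the assertions expect. It does not use DGL or MXNet. The helper names `expected_shapes` and `parameter_grid` and the graph sizes are hypothetical, chosen only for illustration, and the iteration order may differ from pytest's. The `g` and `idtype` parameters come from decorators that are not part of the excerpt, so they are left out here.

```python
from itertools import product

# Values taken from the two @pytest.mark.parametrize decorators in the listing.
OUT_DIMS = [1, 20]
NUM_HEADS = [1, 5]


def expected_shapes(num_dst_nodes, num_edges, out_dim, num_heads):
    """Return the (output, attention) shapes that test_gat_conv asserts.

    Hypothetical helper: it only mirrors the two shape assertions in the test.
    """
    out_shape = (num_dst_nodes, num_heads, out_dim)
    attn_shape = (num_edges, num_heads, 1)
    return out_shape, attn_shape


def parameter_grid():
    """Enumerate the (out_dim, num_heads) pairs produced by stacking the decorators."""
    return list(product(OUT_DIMS, NUM_HEADS))


if __name__ == "__main__":
    # Hypothetical graph with 4 destination nodes and 6 edges.
    for out_dim, num_heads in parameter_grid():
        print(out_dim, num_heads, expected_shapes(4, 6, out_dim, num_heads))
```

Each test case therefore checks one output tensor with a feature vector of length `out_dim` per head per destination node, and one attention tensor with a single scalar per head per edge.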

Callers 1

test_nn.py · File · 0.70

Calls 6

number_of_src_nodes · Method · 0.80
number_of_dst_nodes · Method · 0.80
to · Method · 0.45
astype · Method · 0.45
ctx · Method · 0.45
num_edges · Method · 0.45

Tested by

no test coverage detected