Repository: github.com/dmlc/dgl

Function test_gg_conv

tests/python/mxnet/test_nn.py:284–296
Signature: test_gg_conv()

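```python
# Module-level imports this test relies on (declared at the top of test_nn.py).
import backend as F  # DGL's test-suite backend shim (tests/backend)
import dgl
import dgl.nn.mxnet as nn
import networkx as nx
from mxnet import nd


def test_gg_conv():
    # Random undirected Erdős–Rényi graph with 20 nodes, moved to the test context.
    g = dgl.from_networkx(nx.erdos_renyi_graph(20, 0.3)).to(F.ctx())
    ctx = F.ctx()

    # in_feats=10, out_feats=20, n_steps=3, n_etypes=4
    gg_conv = nn.GatedGraphConv(10, 20, 3, 4)
    gg_conv.initialize(ctx=ctx)
    print(gg_conv)

    # test#1: basic
    h0 = F.randn((20, 10))
    # One edge-type id in [0, 4) for every edge of g.
    etypes = nd.random.randint(0, 4, g.num_edges()).as_in_context(ctx)
    h1 = gg_conv(g, h0, etypes)
    # 20 nodes, each with an out_feats-dimensional (20) representation.
    assert h1.shape == (20, 20)
```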
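The test sizes `etypes` with `g.num_edges()` rather than with the networkx edge count. `dgl.from_networkx` turns each edge of an undirected networkx graph into two directed edges, so the DGL graph has twice as many edges as the Erdős–Rényi sample. The sketch below illustrates this with the standard library only; `erdos_renyi_edges` and `to_bidirected` are hypothetical helpers written for illustration, not networkx or DGL APIs, and the edge ordering they produce is not claimed to match DGL's.

```python
import random


def erdos_renyi_edges(n, p, seed=None):
    """Hypothetical helper: sample an undirected G(n, p) graph.

    Each unordered pair (u, v) with u < v is kept independently with probability p.
    """
    rng = random.Random(seed)
    return [(u, v) for u in range(n) for v in range(u + 1, n) if rng.random() < p]


def to_bidirected(undirected_edges):
    """Hypothetical helper: mirror each undirected edge into two directed edges,
    which is how an undirected networkx graph ends up in DGL."""
    src, dst = [], []
    for u, v in undirected_edges:
        src += [u, v]
        dst += [v, u]
    return src, dst


und = erdos_renyi_edges(20, 0.3, seed=0)
src, dst = to_bidirected(und)
num_edges = len(src)  # analogue of g.num_edges(): twice the undirected count

# One edge-type id in [0, 4) per *directed* edge, mirroring randint(0, 4, ...).
type_rng = random.Random(1)
etypes = [type_rng.randrange(0, 4) for _ in range(num_edges)]
print(len(und), num_edges)
```

Drawing the types per directed edge means the two directions of one undirected edge can receive different types, which is what the test's `randint` call over `g.num_edges()` does as well.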

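The assertion `h1.shape == (20, 20)` follows from how a gated graph convolution works. Each node's input is zero-padded from `in_feats` (10) up to `out_feats` (20). For `n_steps` rounds, every node then sums the messages `W[etype] @ h_src` over its incoming edges, using one weight matrix per edge type, and updates its state with a GRU cell. The output therefore has one `out_feats`-dimensional row per node. The sketch below is a minimal NumPy illustration of that update under stated assumptions: biases are omitted, the GRU uses one common gate formulation, and the function and parameter names are hypothetical. It is not DGL's implementation.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_cell(a, h, params):
    """One GRU step (biases omitted): input a and hidden state h, both (N, d)."""
    w_z, u_z, w_r, u_r, w_n, u_n = params
    z = sigmoid(a @ w_z + h @ u_z)          # update gate
    r = sigmoid(a @ w_r + h @ u_r)          # reset gate
    n = np.tanh(a @ w_n + r * (h @ u_n))    # candidate state
    return (1.0 - z) * n + z * h


def gated_graph_conv(src, dst, etypes, feat, out_feats, n_steps, edge_w, gru_params):
    """Hypothetical sketch of a gated graph convolution forward pass.

    src, dst, etypes: int arrays of length E describing directed edges src -> dst.
    feat: (N, in_feats) node features, with in_feats <= out_feats.
    edge_w: (n_etypes, out_feats, out_feats), one weight matrix per edge type.
    """
    n, in_feats = feat.shape
    assert in_feats <= out_feats, "input is zero-padded up to out_feats"
    # h^(0) = [x || 0]: pad the input features with zeros to out_feats columns.
    h = np.concatenate([feat, np.zeros((n, out_feats - in_feats))], axis=1)
    for _ in range(n_steps):
        # Per-edge message h_src @ W[etype], shape (E, out_feats).
        msg = np.einsum("ei,eio->eo", h[src], edge_w[etypes])
        # Sum incoming messages at each destination node.
        a = np.zeros_like(h)
        np.add.at(a, dst, msg)
        h = gru_cell(a, h, gru_params)
    return h


# Same sizes as the test: 20 nodes, in_feats=10, out_feats=20, 3 steps, 4 edge types.
rng = np.random.default_rng(0)
n_nodes, in_feats, out_feats, n_steps, n_etypes = 20, 10, 20, 3, 4
n_edges = 60  # hypothetical edge count for a random directed graph
src = rng.integers(0, n_nodes, size=n_edges)
dst = rng.integers(0, n_nodes, size=n_edges)
etypes = rng.integers(0, n_etypes, size=n_edges)
edge_w = rng.normal(scale=0.1, size=(n_etypes, out_feats, out_feats))
gru_params = [rng.normal(scale=0.1, size=(out_feats, out_feats)) for _ in range(6)]
h0 = rng.normal(size=(n_nodes, in_feats))

h1 = gated_graph_conv(src, dst, etypes, h0, out_feats, n_steps, edge_w, gru_params)
print(h1.shape)  # (20, 20)
```

Because the hidden state has `out_feats` columns from the first step onward, every per-type weight matrix and GRU parameter is square (`out_feats` × `out_feats`). This is also why the input width must not exceed `out_feats`.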
Callers (1)
test_nn.py · File · 0.85

Calls (3)
to · Method · 0.45
ctx · Method · 0.45
num_edges · Method · 0.45

Tested by
No test coverage detected.