MCPcopy
hub / github.com/dmlc/dgl / test_graph_conv

Function test_graph_conv

tests/python/mxnet/test_nn.py:33–89  ·  view source on GitHub ↗
(idtype, out_dim)

Source from the content-addressed store, hash-verified

@parametrize_idtype
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv(idtype, out_dim):
    """Check ``nn.GraphConv`` on a 3-node path graph.

    Cases:

    * tests #1-#2: ``norm="none"`` with bias; the output is compared
      numerically against the dense reference ``_AXWb``.
    * tests #3-#4: a layer built with default arguments, called outside
      ``autograd.train_mode()``; the output shape is checked.
    * tests #5-#6: a fresh default layer called inside
      ``autograd.train_mode()``; the output shape is checked.
    * Finally, a pre-existing ``g.ndata["h"]`` must survive a forward pass
      unchanged.

    Every case runs on all-ones inputs of shape ``(3, 5)`` and ``(3, 5, 5)``.
    After each call the graph must carry no leftover node/edge data.
    """
    ctx = F.ctx()
    g = dgl.from_networkx(nx.path_graph(3))
    g = g.astype(idtype).to(ctx)
    adj = g.adj_external(transpose=True, ctx=ctx)
    # Node-feature shapes used by every case: a plain (nodes, 5) matrix and a
    # higher-rank (nodes, 5, 5) tensor.
    shapes = ((3, 5), (3, 5, 5))

    def _run_and_check(conv, shape):
        """Run ``conv`` on all-ones features of ``shape``; return ``(h0, h1)``."""
        h0 = F.ones(shape)
        h1 = conv(g, h0)
        # The layer must not leave intermediate features on the caller's graph.
        assert len(g.ndata) == 0
        assert len(g.edata) == 0
        # Leading dims are preserved; the last one is projected to ``out_dim``.
        assert h1.shape == shape[:-1] + (out_dim,)
        return h0, h1

    # tests #1-#2: un-normalized layer with bias, checked value-for-value
    # against the dense reference (presumably A @ X @ W + b).
    conv = nn.GraphConv(5, out_dim, norm="none", bias=True)
    conv.initialize(ctx=ctx)
    for shape in shapes:
        h0, h1 = _run_and_check(conv, shape)
        check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # tests #3-#4: layer with default arguments, outside train_mode.
    conv = nn.GraphConv(5, out_dim)
    conv.initialize(ctx=ctx)
    for shape in shapes:
        _run_and_check(conv, shape)

    # tests #5-#6: fresh default layer, inside train_mode.
    conv = nn.GraphConv(5, out_dim)
    conv.initialize(ctx=ctx)
    with autograd.train_mode():
        for shape in shapes:
            _run_and_check(conv, shape)

    # A user-provided node feature named "h" must not be overwritten or
    # removed by the forward pass.
    g.ndata["h"] = 2 * F.ones((3, 1))
    h1 = conv(g, F.ones((3, 5, 5)))
    assert h1.shape == (3, 5, out_dim)
    assert len(g.ndata) == 1
    assert len(g.edata) == 0
    assert "h" in g.ndata
    check_close(g.ndata["h"], 2 * F.ones((3, 1)))
90

Callers 1

test_nn.py — File · 0.70

Calls 6

check_close — Function · 0.85
adj_external — Method · 0.80
_AXWb — Function · 0.70
to — Method · 0.45
astype — Method · 0.45
ctx — Method · 0.45

Tested by

no test coverage detected