(out_dim)
| 28 | |
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv(out_dim):
    """Exercise the TensorFlow ``nn.GraphConv`` layer on a 3-node path graph.

    Runs the layer on 2-D ``(3, 5)`` and 3-D ``(3, 5, 5)`` all-ones node
    features with two configurations:

    * ``norm="none", bias=True``: the output is compared against the
      reference ``_AXWb(adj, h0, weight, bias)``.
    * default constructor arguments: only the output shape is checked.

    After every forward pass the graph must hold no node or edge data, i.e.
    the layer must not leave temporary features behind.

    Args:
        out_dim: Output feature size passed to ``GraphConv`` (parametrized
            over 1 and 2).
    """
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(3)).to(ctx)
    # Dense transposed adjacency used by the reference computation.
    # ``reorder`` canonicalizes the sparse index order, presumably because
    # ``tf.sparse.to_dense`` validates index ordering by default.
    adj = tf.sparse.to_dense(
        tf.sparse.reorder(g.adj_external(transpose=True, ctx=ctx))
    )
    input_shapes = [(3, 5), (3, 5, 5)]

    def _assert_no_side_data():
        # The forward pass must not leave features attached to the graph.
        assert len(g.ndata) == 0
        assert len(g.edata) == 0

    # test#1 / test#2: un-normalized layer with bias, checked numerically.
    conv = nn.GraphConv(5, out_dim, norm="none", bias=True)
    print(conv)
    for shape in input_shapes:
        h0 = F.ones(shape)
        h1 = conv(g, h0)
        _assert_no_side_data()
        assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # test#3 / test#4: default arguments; the last feature axis must become
    # ``out_dim`` while the leading axes are preserved.
    conv = nn.GraphConv(5, out_dim)
    for shape in input_shapes:
        h1 = conv(g, F.ones(shape))
        _assert_no_side_data()
        assert tuple(h1.shape) == shape[:-1] + (out_dim,)

    # NOTE(review): other backends also test reset_parameters() by comparing
    # ``weight.data`` before and after a reset. That check is not ported
    # here; confirm whether the TF GraphConv supports a reset before adding it.
| 84 | |
| 85 | |
| 86 | @parametrize_idtype |
no test coverage detected