(idtype, out_dim)
@parametrize_idtype
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv(idtype, out_dim):
    """Check ``nn.GraphConv`` outputs and that it leaves the input graph untouched.

    Runs ``GraphConv(5, out_dim)`` on a 3-node path graph with 2-D ``(3, 5)``
    and higher-rank ``(3, 5, 5)`` node features:

    * with ``norm="none"`` and a bias, the output is compared against the
      dense reference ``_AXWb(adj, h, weight, bias)``;
    * with the default constructor arguments, both in inference mode and
      inside ``autograd.train_mode()``, the output shape is checked;
    * after every call, ``g.ndata`` / ``g.edata`` must be unchanged, and a
      pre-existing ``g.ndata["h"]`` must not be overwritten by the layer.
    """
    ctx = F.ctx()
    g = dgl.from_networkx(nx.path_graph(3))
    g = g.astype(idtype).to(ctx)
    adj = g.adj_external(transpose=True, ctx=ctx)

    def run_and_check(conv, h, expected_ndata=0):
        # Apply the layer and check that only the last (feature) axis changes,
        # from in_feats to out_dim. Also check the layer leaves no temporary
        # node/edge features behind on ``g``.
        out = conv(g, h)
        assert tuple(out.shape) == tuple(h.shape[:-1]) + (out_dim,)
        assert len(g.ndata) == expected_ndata
        assert len(g.edata) == 0
        return out

    # Un-normalized layer with bias: compare against the dense reference.
    conv = nn.GraphConv(5, out_dim, norm="none", bias=True)
    conv.initialize(ctx=ctx)
    # test#1: basic
    h0 = F.ones((3, 5))
    h1 = run_and_check(conv, h0)
    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))
    # test#2: more-dim
    h0 = F.ones((3, 5, 5))
    h1 = run_and_check(conv, h0)
    check_close(h1, _AXWb(adj, h0, conv.weight, conv.bias))

    # Layer with default constructor arguments. No dense reference is
    # available here, so only output shapes and graph hygiene are checked.
    conv = nn.GraphConv(5, out_dim)
    conv.initialize(ctx=ctx)
    # test#3: basic
    run_and_check(conv, F.ones((3, 5)))
    # test#4: more-dim
    run_and_check(conv, F.ones((3, 5, 5)))

    # Same checks with a freshly initialized layer in training mode.
    conv = nn.GraphConv(5, out_dim)
    conv.initialize(ctx=ctx)
    with autograd.train_mode():
        # test#5: basic
        run_and_check(conv, F.ones((3, 5)))
        # test#6: more-dim
        run_and_check(conv, F.ones((3, 5, 5)))

    # test#7: a user feature already stored under "h" must survive the call
    # and must be the only node feature left afterwards.
    g.ndata["h"] = 2 * F.ones((3, 1))
    run_and_check(conv, F.ones((3, 5, 5)), expected_ndata=1)
    assert "h" in g.ndata
    check_close(g.ndata["h"], 2 * F.ones((3, 1)))
| 90 |
no test coverage detected