()
| 282 | |
| 283 | |
def test_gg_conv():
    """Smoke-test ``nn.GatedGraphConv`` on a random homogeneous graph.

    Builds an Erdos-Renyi graph with ``num_nodes`` nodes, assigns every edge
    a random type in ``[0, n_etypes)``, runs one forward pass and checks the
    output has shape ``(num_nodes, out_feats)``.
    """
    ctx = F.ctx()
    num_nodes, in_feats, out_feats = 20, 10, 20
    n_steps, n_etypes = 3, 4

    g = dgl.from_networkx(nx.erdos_renyi_graph(num_nodes, 0.3)).to(ctx)

    gg_conv = nn.GatedGraphConv(in_feats, out_feats, n_steps, n_etypes)
    gg_conv.initialize(ctx=ctx)
    print(gg_conv)

    # test#1: basic
    # Put the node features on the same context as the graph and the edge
    # types; as_in_context is a no-op if they are already there.
    h0 = F.randn((num_nodes, in_feats)).as_in_context(ctx)
    # randint's upper bound is exclusive, so edge types fall in [0, n_etypes).
    etypes = nd.random.randint(0, n_etypes, g.num_edges()).as_in_context(ctx)
    h1 = gg_conv(g, h0, etypes)
    assert h1.shape == (num_nodes, out_feats)
| 297 | |
| 298 | |
| 299 | @pytest.mark.parametrize("out_dim", [1, 20]) |
no test coverage detected