(g, idtype, out_dim)
| 434 | ) |
| 435 | @pytest.mark.parametrize("out_dim", [1, 2]) |
| 436 | def test_edge_conv(g, idtype, out_dim):  # check that nn.EdgeConv(5, out_dim) maps src-node features to a (num_dst_nodes, out_dim) output |
| 437 | g = g.astype(idtype).to(F.ctx())  # cast g to the parametrized idtype and move it to the backend context F.ctx() |
| 438 | ctx = F.ctx()  # same context, kept in a local for the parameter initialization below |
| 439 | edge_conv = nn.EdgeConv(5, out_dim)  # 5 must match the feature width of h0 below; out_dim is what the final assert checks |
| 440 | edge_conv.initialize(ctx=ctx)  # allocate module parameters on ctx; NOTE(review): Gluon-style initialize, presumably MXNet backend - confirm |
| 441 | print(edge_conv)  # debug output only; nothing about the repr is asserted |
| 442 | # test #1: basic |
| 443 | h0 = F.randn((g.number_of_src_nodes(), 5))  # random input features: one row per source node, 5 columns |
| 444 | h1 = edge_conv(g, h0)  # forward pass of EdgeConv over graph g |
| 445 | assert h1.shape == (g.number_of_dst_nodes(), out_dim)  # only the output shape is verified, not the values |
| 446 | |
| 447 | |
| 448 | @parametrize_idtype |
no test coverage detected