(g, idtype)
| 316 | "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"]) |
| 317 | ) |
def test_agnn_conv(g, idtype):
    """Run a forward pass of ``nn.AGNNConv`` on ``g`` and check the output shape.

    The graph is cast to ``idtype`` and moved to the backend context. The
    layer receives random features with one row per source node and must
    return one row per destination node with the same feature width.
    """
    feat_dim = 10
    g = g.astype(idtype).to(F.ctx())
    device = F.ctx()

    layer = nn.AGNNConv(0.1, True)
    layer.initialize(ctx=device)
    # Printing the layer exercises its string representation.
    print(layer)

    src_feat = F.randn((g.number_of_src_nodes(), feat_dim))
    out = layer(g, src_feat)

    expected_shape = (g.number_of_dst_nodes(), feat_dim)
    assert out.shape == expected_shape
| 327 | |
| 328 | |
| 329 | @parametrize_idtype |
no test coverage detected