(g, idtype)
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
def test_agnn_conv_bi(g, idtype):
    """Smoke-test AGNNConv on bipartite graphs with (src, dst) feature pairs.

    The layer is built with positional arguments ``(0.1, True)`` and
    initialized on the backend's default context. It is fed a tuple of
    random source-node and destination-node features, each of width 5. The
    test checks that the output has one row per destination node and keeps
    the feature width.
    """
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    feat_dim = 5

    conv = nn.AGNNConv(0.1, True)
    conv.initialize(ctx=ctx)
    print(conv)

    num_src = g.number_of_src_nodes()
    num_dst = g.number_of_dst_nodes()
    # Draw source features before destination features so the RNG
    # consumption order is the same as in the original test.
    src_feat = F.randn((num_src, feat_dim))
    dst_feat = F.randn((num_dst, feat_dim))

    out = conv(g, (src_feat, dst_feat))
    assert out.shape == (num_dst, feat_dim)
| 343 | |
| 344 | |
| 345 | def test_appnp_conv(): |
nothing calls this directly
no test coverage detected