(g, idtype, out_dim, num_heads)
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    """Check GATConv output and attention shapes, with and without residual.

    Args:
        g: input graph; presumably supplied by a parametrize decorator
            outside this excerpt (TODO confirm).
        idtype: index dtype the graph is cast to via ``g.astype``.
        out_dim: per-head output feature size passed to ``nn.GATConv``.
        num_heads: number of attention heads passed to ``nn.GATConv``.
    """
    g = g.astype(idtype).to(F.ctx())
    in_dim = 5
    feat = F.randn((g.number_of_src_nodes(), in_dim))
    # Node outputs are per destination node, per head, per output feature.
    expected_shape = (g.number_of_dst_nodes(), num_heads, out_dim)

    gat = nn.GATConv(in_dim, out_dim, num_heads)
    h = gat(g, feat)
    assert h.shape == expected_shape
    _, a = gat(g, feat, get_attention=True)
    # One attention value per edge per head.
    assert a.shape == (g.num_edges(), num_heads, 1)

    # Test residual connection: enabling the residual branch must leave the
    # output shape unchanged, so assert on it rather than only calling it.
    gat = nn.GATConv(in_dim, out_dim, num_heads, residual=True)
    h = gat(g, feat)
    assert h.shape == expected_shape
| 321 | |
| 322 | |
| 323 | @parametrize_idtype |
no test coverage detected