(g, idtype, out_dim, num_heads)
| 211 | @pytest.mark.parametrize("out_dim", [1, 2]) |
| 212 | @pytest.mark.parametrize("num_heads", [1, 4]) |
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    """Check GATConv output and attention shapes with separate src/dst features.

    ``g`` is presumably a bipartite graph (given the ``_bi`` suffix) --
    confirm against the parametrization above.  Source and destination
    nodes each get their own random feature tensor, passed to the layer as
    a ``(src_feat, dst_feat)`` tuple.  The test asserts that the output has
    one row per destination node with shape ``(num_heads, out_dim)`` per
    row, and that the attention tensor has one row per edge.
    """
    ctx = F.ctx()
    # Reuse the single context lookup for both the graph and the layer.
    g = g.astype(idtype).to(ctx)
    in_dim = 5
    gat = nn.GATConv(in_dim, out_dim, num_heads)
    gat.initialize(ctx=ctx)
    feat = (
        F.randn((g.num_src_nodes(), in_dim)),
        F.randn((g.num_dst_nodes(), in_dim)),
    )
    h = gat(g, feat)
    assert h.shape == (g.num_dst_nodes(), num_heads, out_dim)
    # The trailing positional ``True`` makes the call also return attention
    # weights.  Kept positional: MXNet Gluon's Block.__call__ only forwards
    # positional arguments.
    _, a = gat(g, feat, True)
    assert a.shape == (g.num_edges(), num_heads, 1)
| 226 | |
| 227 | |
| 228 | @parametrize_idtype |
nothing calls this directly
no test coverage detected