(g, idtype, out_dim, num_heads)
@pytest.mark.parametrize("out_dim", [1, 20])
@pytest.mark.parametrize("num_heads", [1, 5])
def test_gat_conv(g, idtype, out_dim, num_heads):
    """Check GATConv output and attention shapes, with and without residual.

    Parameters
    ----------
    g : graph
        Input graph. Presumably supplied by a graph-type parametrization or
        fixture declared above this view; confirm against the full file.
    idtype : dtype
        Index dtype the graph is cast to via ``g.astype``.
    out_dim : int
        Per-head output feature size passed to ``nn.GATConv``.
    num_heads : int
        Number of attention heads passed to ``nn.GATConv``.
    """
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)

    gat = nn.GATConv(10, out_dim, num_heads)
    gat.initialize(ctx=ctx)
    print(gat)  # smoke-test the module's printable representation
    feat = F.randn((g.number_of_src_nodes(), 10))

    # Node output: one out_dim vector per destination node per head.
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)

    # The third argument is passed positionally (not as a keyword) to request
    # attention values alongside the output; Gluon Block.__call__ in MXNet 1.x
    # forwards only positional arguments.
    _, a = gat(g, feat, True)
    assert a.shape == (g.num_edges(), num_heads, 1)

    # test residual connection: the residual path must not change the
    # output shape (previously this section ran the layer but asserted nothing).
    gat = nn.GATConv(10, out_dim, num_heads, residual=True)
    gat.initialize(ctx=ctx)
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
| 207 | |
| 208 | |
| 209 | @parametrize_idtype |
no test coverage detected