()
| 199 | |
| 200 | |
def test_glob_att_pool():
    """Check that GlobalAttentionPooling emits one 10-wide row per graph.

    A 10-node path graph with 5-dim random node features is pooled on its
    own (expecting a 1 x 10 result) and as a batch of four copies
    (expecting a 4 x 10 result).
    """
    path = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())

    # Gate network maps each node to a scalar score; feat network maps to 10 dims.
    pooling = nn.GlobalAttentionPooling(layers.Dense(1), layers.Dense(10))
    print(pooling)

    def _check(graph, expected_rows):
        # Fresh random features sized to the (possibly batched) graph.
        feats = F.randn((graph.num_nodes(), 5))
        pooled = pooling(graph, feats)
        assert (
            pooled.shape[0] == expected_rows
            and pooled.shape[1] == 10
            and pooled.ndim == 2
        )

    # test#1: basic -- a single graph pools to one row
    _check(path, 1)

    # test#2: batched graph -- four copies pool to four rows
    _check(dgl.batch([path] * 4), 4)
| 217 | |
| 218 | |
| 219 | @pytest.mark.parametrize("O", [1, 2, 8]) |
no test coverage detected