()
| 131 | |
| 132 | |
| 133 | def test_simple_pool(): |
| 134 | ctx = F.ctx() |
| 135 | g = dgl.DGLGraph(nx.path_graph(15)).to(F.ctx()) |
| 136 | |
| 137 | sum_pool = nn.SumPooling() |
| 138 | avg_pool = nn.AvgPooling() |
| 139 | max_pool = nn.MaxPooling() |
| 140 | sort_pool = nn.SortPooling(10) # k = 10 |
| 141 | print(sum_pool, avg_pool, max_pool, sort_pool) |
| 142 | |
| 143 | # test#1: basic |
| 144 | h0 = F.randn((g.num_nodes(), 5)) |
| 145 | h1 = sum_pool(g, h0) |
| 146 | assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0)) |
| 147 | h1 = avg_pool(g, h0) |
| 148 | assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0)) |
| 149 | h1 = max_pool(g, h0) |
| 150 | assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0)) |
| 151 | h1 = sort_pool(g, h0) |
| 152 | assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2 |
| 153 | |
| 154 | # test#2: batched graph |
| 155 | g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx()) |
| 156 | bg = dgl.batch([g, g_, g, g_, g]) |
| 157 | h0 = F.randn((bg.num_nodes(), 5)) |
| 158 | h1 = sum_pool(bg, h0) |
| 159 | truth = tf.stack( |
| 160 | [ |
| 161 | F.sum(h0[:15], 0), |
| 162 | F.sum(h0[15:20], 0), |
| 163 | F.sum(h0[20:35], 0), |
| 164 | F.sum(h0[35:40], 0), |
| 165 | F.sum(h0[40:55], 0), |
| 166 | ], |
| 167 | 0, |
| 168 | ) |
| 169 | assert F.allclose(h1, truth) |
| 170 | |
| 171 | h1 = avg_pool(bg, h0) |
| 172 | truth = tf.stack( |
| 173 | [ |
| 174 | F.mean(h0[:15], 0), |
| 175 | F.mean(h0[15:20], 0), |
| 176 | F.mean(h0[20:35], 0), |
| 177 | F.mean(h0[35:40], 0), |
| 178 | F.mean(h0[40:55], 0), |
| 179 | ], |
| 180 | 0, |
| 181 | ) |
| 182 | assert F.allclose(h1, truth) |
| 183 | |
| 184 | h1 = max_pool(bg, h0) |
| 185 | truth = tf.stack( |
| 186 | [ |
| 187 | F.max(h0[:15], 0), |
| 188 | F.max(h0[15:20], 0), |
| 189 | F.max(h0[20:35], 0), |
| 190 | F.max(h0[35:40], 0), |
no test coverage detected