MCPcopy Index your code
hub / github.com/dmlc/dgl / test_simple_pool

Function test_simple_pool

tests/python/tensorflow/test_nn.py:133–198  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

131
132
133def test_simple_pool():
134 ctx = F.ctx()
135 g = dgl.DGLGraph(nx.path_graph(15)).to(F.ctx())
136
137 sum_pool = nn.SumPooling()
138 avg_pool = nn.AvgPooling()
139 max_pool = nn.MaxPooling()
140 sort_pool = nn.SortPooling(10) # k = 10
141 print(sum_pool, avg_pool, max_pool, sort_pool)
142
143 # test#1: basic
144 h0 = F.randn((g.num_nodes(), 5))
145 h1 = sum_pool(g, h0)
146 assert F.allclose(F.squeeze(h1, 0), F.sum(h0, 0))
147 h1 = avg_pool(g, h0)
148 assert F.allclose(F.squeeze(h1, 0), F.mean(h0, 0))
149 h1 = max_pool(g, h0)
150 assert F.allclose(F.squeeze(h1, 0), F.max(h0, 0))
151 h1 = sort_pool(g, h0)
152 assert h1.shape[0] == 1 and h1.shape[1] == 10 * 5 and h1.ndim == 2
153
154 # test#2: batched graph
155 g_ = dgl.DGLGraph(nx.path_graph(5)).to(F.ctx())
156 bg = dgl.batch([g, g_, g, g_, g])
157 h0 = F.randn((bg.num_nodes(), 5))
158 h1 = sum_pool(bg, h0)
159 truth = tf.stack(
160 [
161 F.sum(h0[:15], 0),
162 F.sum(h0[15:20], 0),
163 F.sum(h0[20:35], 0),
164 F.sum(h0[35:40], 0),
165 F.sum(h0[40:55], 0),
166 ],
167 0,
168 )
169 assert F.allclose(h1, truth)
170
171 h1 = avg_pool(bg, h0)
172 truth = tf.stack(
173 [
174 F.mean(h0[:15], 0),
175 F.mean(h0[15:20], 0),
176 F.mean(h0[20:35], 0),
177 F.mean(h0[35:40], 0),
178 F.mean(h0[40:55], 0),
179 ],
180 0,
181 )
182 assert F.allclose(h1, truth)
183
184 h1 = max_pool(bg, h0)
185 truth = tf.stack(
186 [
187 F.max(h0[:15], 0),
188 F.max(h0[15:20], 0),
189 F.max(h0[20:35], 0),
190 F.max(h0[35:40], 0),

Callers 1

test_nn.py · File · 0.70

Calls 3

ctx · Method · 0.45
to · Method · 0.45
num_nodes · Method · 0.45

Tested by

no test coverage detected