MCPcopy Index your code
hub / github.com/dmlc/dgl / test_rgcn

Function test_rgcn

tests/python/tensorflow/test_nn.py:220–299  ·  view source on GitHub ↗
(O)

Source from the content-addressed store, hash-verified

218
219@pytest.mark.parametrize("O", [1, 2, 8])
220def test_rgcn(O):
221 etype = []
222 g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(
223 F.ctx()
224 )
225 # 5 etypes
226 R = 5
227 for i in range(g.num_edges()):
228 etype.append(i % 5)
229 B = 2
230 I = 10
231
232 rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
233 rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
234 rgc_basis_low.weight = rgc_basis.weight
235 rgc_basis_low.w_comp = rgc_basis.w_comp
236 rgc_basis_low.loop_weight = rgc_basis.loop_weight
237 h = tf.random.normal((100, I))
238 r = tf.constant(etype)
239 h_new = rgc_basis(g, h, r)
240 h_new_low = rgc_basis_low(g, h, r)
241 assert list(h_new.shape) == [100, O]
242 assert list(h_new_low.shape) == [100, O]
243 assert F.allclose(h_new, h_new_low)
244
245 if O % B == 0:
246 rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
247 rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
248 rgc_bdd_low.weight = rgc_bdd.weight
249 rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
250 h = tf.random.normal((100, I))
251 r = tf.constant(etype)
252 h_new = rgc_bdd(g, h, r)
253 h_new_low = rgc_bdd_low(g, h, r)
254 assert list(h_new.shape) == [100, O]
255 assert list(h_new_low.shape) == [100, O]
256 assert F.allclose(h_new, h_new_low)
257
258 # with norm
259 norm = tf.zeros((g.num_edges(), 1))
260
261 rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
262 rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True)
263 rgc_basis_low.weight = rgc_basis.weight
264 rgc_basis_low.w_comp = rgc_basis.w_comp
265 rgc_basis_low.loop_weight = rgc_basis.loop_weight
266 h = tf.random.normal((100, I))
267 r = tf.constant(etype)
268 h_new = rgc_basis(g, h, r, norm)
269 h_new_low = rgc_basis_low(g, h, r, norm)
270 assert list(h_new.shape) == [100, O]
271 assert list(h_new_low.shape) == [100, O]
272 assert F.allclose(h_new, h_new_low)
273
274 if O % B == 0:
275 rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
276 rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True)
277 rgc_bdd_low.weight = rgc_bdd.weight

Callers 1

test_nn.py (File) · 0.70

Calls 4

append (Method) · 0.80
to (Method) · 0.45
ctx (Method) · 0.45
num_edges (Method) · 0.45

Tested by

no test coverage detected