MCPcopy
hub / github.com/dmlc/dgl / test_empty_relation

Function test_empty_relation

tests/python/common/test_batch-heterograph.py:337–428  ·  view source on GitHub ↗

Test the features of batched DGLGraphs

(idtype)

Source from the content-addressed store, hash-verified

335)
336@parametrize_idtype
337def test_empty_relation(idtype):
338 """Test the features of batched DGLGraphs"""
339 g1 = dgl.heterograph(
340 {
341 ("user", "follows", "user"): ([0, 1], [1, 2]),
342 ("user", "plays", "game"): ([], []),
343 },
344 idtype=idtype,
345 device=F.ctx(),
346 )
347 g1.nodes["user"].data["h1"] = F.tensor([[0.0], [1.0], [2.0]])
348 g1.nodes["user"].data["h2"] = F.tensor([[3.0], [4.0], [5.0]])
349 g1.edges["follows"].data["h1"] = F.tensor([[0.0], [1.0]])
350 g1.edges["follows"].data["h2"] = F.tensor([[2.0], [3.0]])
351
352 g2 = dgl.heterograph(
353 {
354 ("user", "follows", "user"): ([0, 1], [1, 2]),
355 ("user", "plays", "game"): ([0, 1], [0, 0]),
356 },
357 idtype=idtype,
358 device=F.ctx(),
359 )
360 g2.nodes["user"].data["h1"] = F.tensor([[0.0], [1.0], [2.0]])
361 g2.nodes["user"].data["h2"] = F.tensor([[3.0], [4.0], [5.0]])
362 g2.nodes["game"].data["h1"] = F.tensor([[0.0]])
363 g2.nodes["game"].data["h2"] = F.tensor([[1.0]])
364 g2.edges["follows"].data["h1"] = F.tensor([[0.0], [1.0]])
365 g2.edges["follows"].data["h2"] = F.tensor([[2.0], [3.0]])
366 g2.edges["plays"].data["h1"] = F.tensor([[0.0], [1.0]])
367
368 bg = dgl.batch([g1, g2])
369
370 # Test number of nodes
371 for ntype in bg.ntypes:
372 assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [
373 g1.num_nodes(ntype),
374 g2.num_nodes(ntype),
375 ]
376
377 # Test number of edges
378 for etype in bg.canonical_etypes:
379 assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == [
380 g1.num_edges(etype),
381 g2.num_edges(etype),
382 ]
383
384 # Test features
385 assert F.allclose(
386 bg.nodes["user"].data["h1"],
387 F.cat(
388 [g1.nodes["user"].data["h1"], g2.nodes["user"].data["h1"]], dim=0
389 ),
390 )
391 assert F.allclose(
392 bg.nodes["user"].data["h2"],
393 F.cat(
394 [g1.nodes["user"].data["h2"], g2.nodes["user"].data["h2"]], dim=0

Callers

nothing calls this directly

Calls 7

asnumpy · Method · 0.80
batch_num_nodes · Method · 0.80
batch_num_edges · Method · 0.80
ctx · Method · 0.45
num_nodes · Method · 0.45
num_edges · Method · 0.45

Tested by

no test coverage detected