Repository: github.com/dmlc/dgl

Function test_batch_unbatch1

tests/python/common/test_batch-graph.py, lines 70–87
Signature: test_batch_unbatch1(idtype)


```python
# Excerpt: dgl, F (the backend-agnostic test module), parametrize_idtype,
# tree1 and tree2 come from the surrounding test file.
@parametrize_idtype  # run the test once per graph index dtype
def test_batch_unbatch1(idtype):
    # Two small trees, each with 5 nodes, 4 edges and an "h" feature
    # on both nodes and edges.
    t1 = tree1(idtype)
    t2 = tree2(idtype)
    b1 = dgl.batch([t1, t2])
    # Batching an already-batched graph flattens it: b2 holds t2, t1, t2.
    b2 = dgl.batch([t2, b1])
    assert b2.num_nodes() == 15
    assert b2.num_edges() == 12
    assert b2.batch_size == 3
    assert F.allclose(b2.batch_num_nodes(), F.tensor([5, 5, 5]))
    assert F.allclose(b2.batch_num_edges(), F.tensor([4, 4, 4]))

    # Unbatching returns the component graphs in batch order,
    # with their node and edge features intact.
    s1, s2, s3 = dgl.unbatch(b2)
    assert F.allclose(t2.ndata["h"], s1.ndata["h"])
    assert F.allclose(t2.edata["h"], s1.edata["h"])
    assert F.allclose(t1.ndata["h"], s2.ndata["h"])
    assert F.allclose(t1.edata["h"], s2.edata["h"])
    assert F.allclose(t2.ndata["h"], s3.ndata["h"])
    assert F.allclose(t2.edata["h"], s3.edata["h"])
```
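The test relies on two properties of dgl.batch: each input's node IDs are shifted by the number of nodes that precede it, and batching an already-batched graph flattens it, so b2 reports a batch size of 3 rather than 2. The sketch below models that bookkeeping in plain Python so it runs without DGL. ToyGraph, toy_batch and toy_unbatch are hypothetical names, and the sketch illustrates the idea only; it is not DGL's implementation.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyGraph:
    """Hypothetical stand-in for a (possibly batched) graph."""
    src: List[int]          # edge source node IDs
    dst: List[int]          # edge destination node IDs
    nfeat: List[float]      # one feature value per node (like ndata["h"])
    efeat: List[float]      # one feature value per edge (like edata["h"])
    batch_num_nodes: Optional[List[int]] = None  # nodes per component graph
    batch_num_edges: Optional[List[int]] = None  # edges per component graph

    def __post_init__(self):
        # An unbatched graph is treated as a batch of one.
        if self.batch_num_nodes is None:
            self.batch_num_nodes = [len(self.nfeat)]
        if self.batch_num_edges is None:
            self.batch_num_edges = [len(self.efeat)]

    @property
    def batch_size(self) -> int:
        return len(self.batch_num_nodes)


def toy_batch(graphs: List[ToyGraph]) -> ToyGraph:
    """Concatenate graphs, shifting node IDs by the nodes that come before."""
    src, dst, nfeat, efeat, bnn, bne = [], [], [], [], [], []
    offset = 0
    for g in graphs:
        src += [u + offset for u in g.src]
        dst += [v + offset for v in g.dst]
        nfeat += g.nfeat
        efeat += g.efeat
        # Copy the per-component sizes, so a batched input is flattened
        # into its components instead of counting as one graph.
        bnn += g.batch_num_nodes
        bne += g.batch_num_edges
        offset += len(g.nfeat)
    return ToyGraph(src, dst, nfeat, efeat, bnn, bne)


def toy_unbatch(bg: ToyGraph) -> List[ToyGraph]:
    """Split a batched graph back into components using the stored sizes."""
    parts, n0, e0 = [], 0, 0
    for nn, ne in zip(bg.batch_num_nodes, bg.batch_num_edges):
        parts.append(ToyGraph(
            [u - n0 for u in bg.src[e0:e0 + ne]],  # undo the node-ID shift
            [v - n0 for v in bg.dst[e0:e0 + ne]],
            bg.nfeat[n0:n0 + nn],
            bg.efeat[e0:e0 + ne],
        ))
        n0 += nn
        e0 += ne
    return parts


# Mirror the test with two made-up 5-node, 4-edge trees.
t1 = ToyGraph([0, 0, 1, 1], [1, 2, 3, 4], [0.0, 1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 2.0, 3.0])
t2 = ToyGraph([0, 0, 2, 2], [1, 2, 3, 4], [5.0, 6.0, 7.0, 8.0, 9.0], [4.0, 5.0, 6.0, 7.0])
b2 = toy_batch([t2, toy_batch([t1, t2])])
print(len(b2.nfeat), len(b2.efeat), b2.batch_size, b2.batch_num_nodes)  # 15 12 3 [5, 5, 5]
```

Storing per-component node and edge counts is what makes unbatching possible: the counts give the slice boundaries and the node-ID offsets to subtract, so toy_unbatch(b2) yields graphs equal to t2, t1 and t2, in that order.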

Callers

nothing calls this directly

Calls (6)

tree1 (function) · 0.85
tree2 (function) · 0.85
batch_num_nodes (method) · 0.80
batch_num_edges (method) · 0.80
num_nodes (method) · 0.45
num_edges (method) · 0.45

Tested by

no test coverage detected