MCPcopy
hub / github.com/dmlc/dgl / test_batching_batched

Function test_batching_batched

tests/python/common/test_batch-heterograph.py:125–200  ·  view source on GitHub ↗

Test batching a DGLGraph and a batched DGLGraph.

Signature: test_batching_batched(idtype)

Source retrieved from the content-addressed store (hash-verified).

123
124@parametrize_idtype
125def test_batching_batched(idtype):
126 """Test batching a DGLGraph and a batched DGLGraph."""
127 g1 = dgl.heterograph(
128 {
129 ("user", "follows", "user"): ([0, 1], [1, 2]),
130 ("user", "plays", "game"): ([0, 1], [0, 0]),
131 },
132 idtype=idtype,
133 device=F.ctx(),
134 )
135 g2 = dgl.heterograph(
136 {
137 ("user", "follows", "user"): ([0, 1], [1, 2]),
138 ("user", "plays", "game"): ([0, 1], [0, 0]),
139 },
140 idtype=idtype,
141 device=F.ctx(),
142 )
143 bg1 = dgl.batch([g1, g2])
144 g3 = dgl.heterograph(
145 {
146 ("user", "follows", "user"): ([0], [1]),
147 ("user", "plays", "game"): ([1], [0]),
148 },
149 idtype=idtype,
150 device=F.ctx(),
151 )
152 bg2 = dgl.batch([bg1, g3])
153 assert bg2.idtype == idtype
154 assert bg2.device == F.ctx()
155 assert bg2.ntypes == g3.ntypes
156 assert bg2.etypes == g3.etypes
157 assert bg2.canonical_etypes == g3.canonical_etypes
158 assert bg2.batch_size == 3
159
160 # Test number of nodes
161 for ntype in bg2.ntypes:
162 assert F.asnumpy(bg2.batch_num_nodes(ntype)).tolist() == [
163 g1.num_nodes(ntype),
164 g2.num_nodes(ntype),
165 g3.num_nodes(ntype),
166 ]
167 assert bg2.num_nodes(ntype) == (
168 g1.num_nodes(ntype) + g2.num_nodes(ntype) + g3.num_nodes(ntype)
169 )
170
171 # Test number of edges
172 for etype in bg2.canonical_etypes:
173 assert F.asnumpy(bg2.batch_num_edges(etype)).tolist() == [
174 g1.num_edges(etype),
175 g2.num_edges(etype),
176 g3.num_edges(etype),
177 ]
178 assert bg2.num_edges(etype) == (
179 g1.num_edges(etype) + g2.num_edges(etype) + g3.num_edges(etype)
180 )
181
182 # Test relabeled nodes

Callers

nothing calls this directly

Calls 9 (8 listed below)

asnumpy — Method · 0.80
batch_num_nodes — Method · 0.80
batch_num_edges — Method · 0.80
ctx — Method · 0.45
num_nodes — Method · 0.45
num_edges — Method · 0.45
nodes — Method · 0.45
edges — Method · 0.45

Tested by

no test coverage detected