MCPcopy
hub / github.com/dmlc/dgl / test_set_batch_info

Function test_set_batch_info

tests/python/common/test_batch-graph.py:274–319  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

272
@parametrize_idtype
def test_set_batch_info(idtype):
    """Check that recomputed batch info lets a subgraph of a batch be unbatched.

    Two random graphs are batched and a node- or edge-induced subgraph of the
    batch is taken. ``_get_subgraph_batch_info`` computes the subgraph's
    per-graph node/edge counts from the induced IDs. These are attached with
    ``set_batch_num_nodes`` / ``set_batch_num_edges``. ``dgl.unbatch`` must
    then yield pieces whose node and edge counts match the same subgraph
    taken from each original graph.
    """
    ctx = F.ctx()

    g1 = dgl.rand_graph(30, 100).astype(idtype).to(ctx)
    g2 = dgl.rand_graph(40, 200).astype(idtype).to(ctx)
    bg = dgl.batch([g1, g2])
    batch_num_nodes = F.astype(bg.batch_num_nodes(), idtype)
    batch_num_edges = F.astype(bg.batch_num_edges(), idtype)

    # test homogeneous node subgraph
    # Batched nodes 10-19 are g1's nodes 10-19. Batched nodes 50-59 are g2's
    # local nodes 20-29, because g2's node IDs are offset by g1's 30 nodes.
    sg_n = dgl.node_subgraph(bg, list(range(10, 20)) + list(range(50, 60)))
    induced_nodes = sg_n.ndata["_ID"]
    induced_edges = sg_n.edata["_ID"]
    new_batch_num_nodes = _get_subgraph_batch_info(
        bg.ntypes, [induced_nodes], batch_num_nodes
    )
    new_batch_num_edges = _get_subgraph_batch_info(
        bg.canonical_etypes, [induced_edges], batch_num_edges
    )
    sg_n.set_batch_num_nodes(new_batch_num_nodes)
    sg_n.set_batch_num_edges(new_batch_num_edges)
    subg_n1, subg_n2 = dgl.unbatch(sg_n)
    subg1 = dgl.node_subgraph(g1, list(range(10, 20)))
    subg2 = dgl.node_subgraph(g2, list(range(20, 30)))
    # Both node and edge batch counts were set, so check both.
    assert subg_n1.num_nodes() == subg1.num_nodes()
    assert subg_n2.num_nodes() == subg2.num_nodes()
    assert subg_n1.num_edges() == subg1.num_edges()
    assert subg_n2.num_edges() == subg2.num_edges()

    # test homogeneous edge subgraph
    # Batched edges 40-69 are g1's edges 40-69. Batched edges 150-199 are g2's
    # local edges 50-99, because g2's edge IDs are offset by g1's 100 edges.
    sg_e = dgl.edge_subgraph(
        bg, list(range(40, 70)) + list(range(150, 200)), relabel_nodes=False
    )
    # relabel_nodes=False keeps every node of bg, so every node ID counts as
    # "induced" when recomputing the per-graph node counts.
    induced_nodes = F.arange(0, bg.num_nodes(), idtype)
    induced_edges = sg_e.edata["_ID"]
    new_batch_num_nodes = _get_subgraph_batch_info(
        bg.ntypes, [induced_nodes], batch_num_nodes
    )
    new_batch_num_edges = _get_subgraph_batch_info(
        bg.canonical_etypes, [induced_edges], batch_num_edges
    )
    sg_e.set_batch_num_nodes(new_batch_num_nodes)
    sg_e.set_batch_num_edges(new_batch_num_edges)
    subg_e1, subg_e2 = dgl.unbatch(sg_e)
    subg1 = dgl.edge_subgraph(g1, list(range(40, 70)), relabel_nodes=False)
    subg2 = dgl.edge_subgraph(g2, list(range(50, 100)), relabel_nodes=False)
    # Both node and edge batch counts were set, so check both.
    assert subg_e1.num_nodes() == subg1.num_nodes()
    assert subg_e2.num_nodes() == subg2.num_nodes()
    assert subg_e1.num_edges() == subg1.num_edges()
    assert subg_e2.num_edges() == subg2.num_edges()
320
321
322if __name__ == "__main__":

Callers 1

Calls 12

_get_subgraph_batch_infoFunction · 0.85
batch_num_nodesMethod · 0.80
batch_num_edgesMethod · 0.80
set_batch_num_nodesMethod · 0.80
set_batch_num_edgesMethod · 0.80
ctxMethod · 0.45
toMethod · 0.45
astypeMethod · 0.45
node_subgraphMethod · 0.45
num_edgesMethod · 0.45
edge_subgraphMethod · 0.45
num_nodesMethod · 0.45

Tested by

no test coverage detected