(idtype)
| 272 | |
@parametrize_idtype
def test_set_batch_info(idtype):
    """Check that ``_get_subgraph_batch_info`` restores per-graph batch info.

    A subgraph is taken from a batched graph, its batch info is recomputed
    from the induced node/edge IDs, and the subgraph is unbatched.  Each
    unbatched piece must match the subgraph taken directly from the
    corresponding component graph, in both node and edge counts.

    Covers a homogeneous node-induced subgraph and a homogeneous
    edge-induced subgraph (``relabel_nodes=False``).
    """
    ctx = F.ctx()

    # In the batch, g1 owns nodes [0, 30) and edges [0, 100);
    # g2 owns nodes [30, 70) and edges [100, 300).
    g1 = dgl.rand_graph(30, 100).astype(idtype).to(ctx)
    g2 = dgl.rand_graph(40, 200).astype(idtype).to(ctx)
    bg = dgl.batch([g1, g2])
    batch_num_nodes = F.astype(bg.batch_num_nodes(), idtype)
    batch_num_edges = F.astype(bg.batch_num_edges(), idtype)

    # test homogeneous node subgraph
    # Batched node IDs 50..59 are g2's local node IDs 20..29.
    sg_n = dgl.node_subgraph(bg, list(range(10, 20)) + list(range(50, 60)))
    induced_nodes = sg_n.ndata["_ID"]
    induced_edges = sg_n.edata["_ID"]
    new_batch_num_nodes = _get_subgraph_batch_info(
        bg.ntypes, [induced_nodes], batch_num_nodes
    )
    new_batch_num_edges = _get_subgraph_batch_info(
        bg.canonical_etypes, [induced_edges], batch_num_edges
    )
    sg_n.set_batch_num_nodes(new_batch_num_nodes)
    sg_n.set_batch_num_edges(new_batch_num_edges)
    subg_n1, subg_n2 = dgl.unbatch(sg_n)
    subg1 = dgl.node_subgraph(g1, list(range(10, 20)))
    subg2 = dgl.node_subgraph(g2, list(range(20, 30)))
    # Check both node and edge splits: a wrong node split would otherwise
    # go unnoticed if the edge totals happened to line up.
    assert subg_n1.num_nodes() == subg1.num_nodes()
    assert subg_n2.num_nodes() == subg2.num_nodes()
    assert subg_n1.num_edges() == subg1.num_edges()
    assert subg_n2.num_edges() == subg2.num_edges()

    # test homogeneous edge subgraph
    # Batched edge IDs 150..199 are g2's local edge IDs 50..99.  With
    # relabel_nodes=False every node is kept, so the induced nodes are all
    # node IDs of the batch.
    sg_e = dgl.edge_subgraph(
        bg, list(range(40, 70)) + list(range(150, 200)), relabel_nodes=False
    )
    induced_nodes = F.arange(0, bg.num_nodes(), idtype)
    induced_edges = sg_e.edata["_ID"]
    new_batch_num_nodes = _get_subgraph_batch_info(
        bg.ntypes, [induced_nodes], batch_num_nodes
    )
    new_batch_num_edges = _get_subgraph_batch_info(
        bg.canonical_etypes, [induced_edges], batch_num_edges
    )
    sg_e.set_batch_num_nodes(new_batch_num_nodes)
    sg_e.set_batch_num_edges(new_batch_num_edges)
    subg_e1, subg_e2 = dgl.unbatch(sg_e)
    subg1 = dgl.edge_subgraph(g1, list(range(40, 70)), relabel_nodes=False)
    subg2 = dgl.edge_subgraph(g2, list(range(50, 100)), relabel_nodes=False)
    # Check both splits: the edge split is the one this case actually
    # exercises, so it must be asserted alongside the node split.
    assert subg_e1.num_nodes() == subg1.num_nodes()
    assert subg_e2.num_nodes() == subg2.num_nodes()
    assert subg_e1.num_edges() == subg1.num_edges()
    assert subg_e2.num_edges() == subg2.num_edges()
| 320 | |
| 321 | |
| 322 | if __name__ == "__main__": |
no test coverage detected