Get a particular graph from a batch of graphs. Parameters ---------- g : DGLGraph Input batched graph. gid : int The ID of the graph to retrieve. store_ids : bool If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata`` and ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, respectively.
(g, gid, store_ids=False)
| 444 | |
| 445 | |
| 446 | def slice_batch(g, gid, store_ids=False): |
| 447 | """Get a particular graph from a batch of graphs. |
| 448 | |
| 449 | Parameters |
| 450 | ---------- |
| 451 | g : DGLGraph |
| 452 | Input batched graph. |
| 453 | gid : int |
| 454 | The ID of the graph to retrieve. |
| 455 | store_ids : bool |
| 456 | If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata`` and |
| 457 | ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, respectively. |
| 458 | |
| 459 | Returns |
| 460 | ------- |
| 461 | DGLGraph |
| 462 | Retrieved graph. |
| 463 | |
| 464 | Examples |
| 465 | -------- |
| 466 | |
| 467 | The following example uses PyTorch backend. |
| 468 | |
| 469 | >>> import dgl |
| 470 | >>> import torch |
| 471 | |
| 472 | Create a batched graph. |
| 473 | |
| 474 | >>> g1 = dgl.graph(([0, 1], [2, 3])) |
| 475 | >>> g2 = dgl.graph(([1], [2])) |
| 476 | >>> bg = dgl.batch([g1, g2]) |
| 477 | |
| 478 | Get the second component graph. |
| 479 | |
| 480 | >>> g = dgl.slice_batch(bg, 1) |
| 481 | >>> print(g) |
| 482 | Graph(num_nodes=3, num_edges=1, |
| 483 | ndata_schemes={} |
| 484 | edata_schemes={}) |
| 485 | """ |
| 486 | start_nid = [] |
| 487 | num_nodes = [] |
| 488 | for ntype in g.ntypes: |
| 489 | batch_num_nodes = g.batch_num_nodes(ntype) |
| 490 | num_nodes.append(F.as_scalar(batch_num_nodes[gid])) |
| 491 | if gid == 0: |
| 492 | start_nid.append(0) |
| 493 | else: |
| 494 | start_nid.append( |
| 495 | F.as_scalar(F.sum(F.slice_axis(batch_num_nodes, 0, 0, gid), 0)) |
| 496 | ) |
| 497 | |
| 498 | start_eid = [] |
| 499 | num_edges = [] |
| 500 | for etype in g.canonical_etypes: |
| 501 | batch_num_edges = g.batch_num_edges(etype) |
| 502 | num_edges.append(F.as_scalar(batch_num_edges[gid])) |
| 503 | if gid == 0: |
nothing calls this directly
no test coverage detected