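The docstring in this listing states that the method does not validate the batching information against the graph structure, and that the caller must guarantee there are no cross-graph edges. As a rough illustration of what that guarantee means, the sketch below uses plain Python (no DGL) to turn per-graph node counts into node-ID ranges and check whether any edge crosses a graph boundary. The helper names `graph_offsets`, `graph_of`, `has_cross_graph_edges` and `edges_per_graph` are hypothetical and are not part of DGL. The edge list is the one from the homogeneous example in the docstring.

```python
from itertools import accumulate


def graph_offsets(batch_num_nodes):
    """Return cumulative node-ID boundaries, e.g. [3, 3] -> [0, 3, 6]."""
    return [0] + list(accumulate(batch_num_nodes))


def graph_of(node_id, offsets):
    """Return the index of the graph whose node-ID range contains node_id."""
    for i in range(len(offsets) - 1):
        if offsets[i] <= node_id < offsets[i + 1]:
            return i
    raise ValueError(f"node {node_id} is outside the batch")


def has_cross_graph_edges(src, dst, batch_num_nodes):
    """Return True if any edge connects nodes that belong to different graphs."""
    offsets = graph_offsets(batch_num_nodes)
    return any(graph_of(u, offsets) != graph_of(v, offsets)
               for u, v in zip(src, dst))


def edges_per_graph(src, batch_num_nodes):
    """Count edges per graph, using the graph of each edge's source node."""
    offsets = graph_offsets(batch_num_nodes)
    counts = [0] * len(batch_num_nodes)
    for u in src:
        counts[graph_of(u, offsets)] += 1
    return counts


# Edge list taken from the homogeneous example in the docstring.
src = [0, 1, 2, 3, 4, 5]
dst = [1, 2, 0, 4, 5, 3]

print(has_cross_graph_edges(src, dst, [3, 3]))  # False: two separate triangles
print(has_cross_graph_edges(src, dst, [2, 4]))  # True: edge 1 -> 2 crosses graphs
print(edges_per_graph(src, [3, 3]))             # [3, 3]
```

With node counts `[3, 3]` every edge stays inside one graph, and the per-graph edge counts come out as `[3, 3]`, matching the values the example passes to `set_batch_num_edges`. With `[2, 4]` the same edge list would violate the stated requirement.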
| 1524 | return self._batch_num_nodes[ntype] |
| 1525 | |
| 1526 | def set_batch_num_nodes(self, val): |
| 1527 | """Manually set the number of nodes for each graph in the batch, for each node |
| 1528 | type. |
| 1529 | |
| 1530 | Parameters |
| 1531 | ---------- |
| 1532 | val : Tensor or Mapping[str, Tensor] |
| 1533 | A dictionary mapping each node type to the number of nodes of each graph in the batch. |
| 1534 | If the graph has only one node type, ``val`` can also be a single tensor holding the |
| 1535 | number of nodes per graph in the batch. |
| 1536 | |
| 1537 | Notes |
| 1538 | ----- |
| 1539 | This API is always used together with ``set_batch_num_edges`` to specify the batching |
| 1540 | information of a graph. It does not check the correspondence between the graph structure |
| 1541 | and the batching information, so the user must guarantee that there are no cross-graph |
| 1542 | edges in the batch. |
| 1543 | |
| 1544 | Examples |
| 1545 | -------- |
| 1546 | |
| 1547 | The following example uses PyTorch backend. |
| 1548 | |
| 1549 | >>> import dgl |
| 1550 | >>> import torch |
| 1551 | |
| 1552 | Create a homogeneous graph. |
| 1553 | |
| 1554 | >>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3])) |
| 1555 | |
| 1556 | Manually set batch information. |
| 1557 | |
| 1558 | >>> g.set_batch_num_nodes(torch.tensor([3, 3])) |
| 1559 | >>> g.set_batch_num_edges(torch.tensor([3, 3])) |
| 1560 | |
| 1561 | Unbatch the graph. |
| 1562 | |
| 1563 | >>> dgl.unbatch(g) |
| 1564 | [Graph(num_nodes=3, num_edges=3, |
| 1565 | ndata_schemes={} |
| 1566 | edata_schemes={}), Graph(num_nodes=3, num_edges=3, |
| 1567 | ndata_schemes={} |
| 1568 | edata_schemes={})] |
| 1569 | |
| 1570 | Create a heterogeneous graph. |
| 1571 | |
| 1572 | >>> hg = dgl.heterograph({ |
| 1573 | ... ('user', 'plays', 'game') : ([0, 1, 2, 3, 4, 5], [0, 1, 1, 3, 3, 2]), |
| 1574 | ... ('developer', 'develops', 'game') : ([0, 1, 2, 3], [1, 0, 3, 2])}) |
| 1575 | |
| 1576 | Manually set batch information. |
| 1577 | |
| 1578 | >>> hg.set_batch_num_nodes({ |
| 1579 | ... 'user': torch.tensor([3, 3]), |
| 1580 | ... 'game': torch.tensor([2, 2]), |
| 1581 | ... 'developer': torch.tensor([2, 2])}) |
| 1582 | >>> hg.set_batch_num_edges({ |
| 1583 | ... ('user', 'plays', 'game'): torch.tensor([3, 3]), |