Revert the batch operation by splitting the given graph into a list of smaller graphs. This is the reverse operation of :func:`dgl.batch`. If ``node_split`` or ``edge_split`` is not given, it calls :func:`DGLGraph.batch_num_nodes` and :func:`DGLGraph.batch_num_edges` of the input graph to get the information.
(g, node_split=None, edge_split=None)
| 254 | |
| 255 | |
| 256 | def unbatch(g, node_split=None, edge_split=None): |
| 257 | """Revert the batch operation by splitting the given graph into a list of smaller graphs. |
| 258 | |
| 259 | This is the reverse operation of :func:`dgl.batch`. If ``node_split`` |
| 260 | or ``edge_split`` is not given, it calls :func:`DGLGraph.batch_num_nodes` |
| 261 | and :func:`DGLGraph.batch_num_edges` of the input graph to get the information. |
| 262 | |
| 263 | If the ``node_split`` or the ``edge_split`` argument is given, |
| 264 | it will partition the graph according to the given segments. One must ensure |
| 265 | that the partition is valid -- edges of the i^th graph only connect nodes |
| 266 | that belong to the i^th graph. Otherwise, DGL will throw an error. |
| 267 | |
| 268 | The function supports heterograph input, in which case the two split |
| 269 | arguments must be dictionaries -- similar to the |
| 270 | :func:`DGLGraph.batch_num_nodes` |
| 271 | and :func:`DGLGraph.batch_num_edges` attributes of a heterograph. |
| 272 | |
| 273 | Parameters |
| 274 | ---------- |
| 275 | g : DGLGraph |
| 276 | Input graph to unbatch. |
| 277 | node_split : Tensor or dict[str, Tensor], optional |
| 278 | Number of nodes of each result graph. |
| 279 | edge_split : Tensor or dict[str, Tensor], optional |
| 280 | Number of edges of each result graph. |
| 281 | |
| 282 | Returns |
| 283 | ------- |
| 284 | list[DGLGraph] |
| 285 | Unbatched list of graphs. |
| 286 | |
| 287 | Examples |
| 288 | -------- |
| 289 | |
| 290 | Unbatch a batched graph |
| 291 | |
| 292 | >>> import dgl |
| 293 | >>> import torch as th |
| 294 | >>> # 4 nodes, 3 edges |
| 295 | >>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3]))) |
| 296 | >>> # 3 nodes, 4 edges |
| 297 | >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0]))) |
| 298 | >>> # add features |
| 299 | >>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3) |
| 300 | >>> g1.edata['w'] = th.ones(g1.num_edges(), 2) |
| 301 | >>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3) |
| 302 | >>> g2.edata['w'] = th.zeros(g2.num_edges(), 2) |
| 303 | >>> bg = dgl.batch([g1, g2]) |
| 304 | >>> f1, f2 = dgl.unbatch(bg) |
| 305 | >>> f1 |
| 306 | Graph(num_nodes=4, num_edges=3, |
| 307 | ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)} |
| 308 | edata_schemes={'w' : Scheme(shape=(2,), dtype=torch.float32)}) |
| 309 | >>> f2 |
| 310 | Graph(num_nodes=3, num_edges=4, |
| 311 | ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)} |
| 312 | edata_schemes={'w' : Scheme(shape=(2,), dtype=torch.float32)}) |
| 313 |
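The docstring requires that a custom partition be valid: every edge of the i-th result graph must connect nodes that also fall in the i-th node segment. The sketch below illustrates that rule in plain Python, without DGL. The helper `split_edges` is hypothetical and is not DGL's implementation. It assumes that nodes and edges of the batched graph are numbered consecutively graph by graph, and it uses plain lists in place of tensors.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def split_edges(
    edges: List[Edge], node_split: List[int], edge_split: List[int]
) -> List[List[Edge]]:
    """Partition a batched edge list into per-graph edge lists.

    Hypothetical helper for illustration only. Node IDs in the result are
    relabeled so that each graph starts at node 0.
    """
    if len(node_split) != len(edge_split):
        raise ValueError("node_split and edge_split must have the same length")
    if sum(edge_split) != len(edges):
        raise ValueError("edge_split must sum to the number of edges")

    parts: List[List[Edge]] = []
    node_start = 0
    edge_start = 0
    for n_nodes, n_edges in zip(node_split, edge_split):
        node_end = node_start + n_nodes
        part: List[Edge] = []
        for src, dst in edges[edge_start:edge_start + n_edges]:
            # Validity check: both endpoints must lie in this graph's segment.
            if not (node_start <= src < node_end and node_start <= dst < node_end):
                raise ValueError(
                    f"edge ({src}, {dst}) crosses the node segment "
                    f"[{node_start}, {node_end})"
                )
            part.append((src - node_start, dst - node_start))
        parts.append(part)
        node_start = node_end
        edge_start += n_edges
    return parts


# Edges of g1 and g2 from the docstring example after batching:
# the node IDs of g2 are shifted by the 4 nodes of g1.
batched = [(0, 1), (1, 2), (2, 3), (4, 4), (4, 5), (4, 6), (5, 4)]
parts = split_edges(batched, node_split=[4, 3], edge_split=[3, 4])
```

With `node_split=[4, 3]` and `edge_split=[3, 4]` the call recovers the two original edge lists. Swapping the node split to `[3, 4]` raises `ValueError`, because the edge `(2, 3)` would then leave the first segment.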
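For heterograph input the docstring says the split arguments are dictionaries. The sketch below shows one way to read such a dictionary, assuming each value lists one count per result graph. The helper `per_graph_counts` and the type names `"user"` and `"item"` are hypothetical, and plain lists stand in for tensors.

```python
from typing import Dict, List


def per_graph_counts(split: Dict[str, List[int]]) -> List[Dict[str, int]]:
    """Turn a per-type split dictionary into one count dictionary per graph.

    Hypothetical helper for illustration only; it is not part of DGL.
    """
    lengths = {len(counts) for counts in split.values()}
    if len(lengths) > 1:
        raise ValueError("every type must list one count per result graph")
    num_graphs = lengths.pop() if lengths else 0
    return [
        {key: counts[i] for key, counts in split.items()}
        for i in range(num_graphs)
    ]


# Two result graphs: the first has 2 "user" and 3 "item" nodes,
# the second has 1 "user" and 2 "item" nodes.
node_split = {"user": [2, 1], "item": [3, 2]}
counts = per_graph_counts(node_split)
```

Here `counts` holds one dictionary per result graph, which mirrors how a per-type split describes the sizes of each unbatched piece.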