hub / github.com/dmlc/dgl / unbatch

Function unbatch

python/dgl/batch.py:256–443

Revert the batch operation by splitting the given graph into a list of small ones. This is the reverse operation of :func:`dgl.batch`. If ``node_split`` or ``edge_split`` is not given, it calls :func:`DGLGraph.batch_num_nodes` and :func:`DGLGraph.batch_num_edges` of the input graph to get the information.

(g, node_split=None, edge_split=None)
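The docstring says the two split arguments may be either a single tensor or a dictionary keyed by type. As a rough illustration of how such arguments could be brought into one uniform shape, here is a pure-Python sketch. The helpers `normalize_split` and `check_same_length` and the type names `"_N"` and `"follows"` are hypothetical, plain lists stand in for tensors, and the rule that every type must list the same number of result graphs is an assumption made for this sketch rather than a statement about DGL's code.

```python
from collections.abc import Mapping


def normalize_split(split, type_names):
    """Return a {type_name: [sizes...]} dict from a bare sequence or a mapping.

    Hypothetical helper for illustration; it is not part of DGL.
    """
    if not isinstance(split, Mapping):
        # A bare sequence is only unambiguous when there is a single type.
        if len(type_names) != 1:
            raise ValueError("a dict is required when there are multiple types")
        split = {type_names[0]: split}
    if set(split) != set(type_names):
        raise ValueError("a split must be given for every type")
    return {name: [int(n) for n in sizes] for name, sizes in split.items()}


def check_same_length(*split_dicts):
    """Return the number of result graphs, requiring all types to agree on it."""
    lengths = {len(sizes) for d in split_dicts for sizes in d.values()}
    if len(lengths) > 1:
        raise ValueError("all splits must list the same number of graphs")
    return lengths.pop() if lengths else 0


# Single node type given as a bare list, single edge type given as a dict.
node_split = normalize_split([4, 3], ["_N"])
edge_split = normalize_split({"follows": [3, 4]}, ["follows"])
num_graphs = check_same_length(node_split, edge_split)
```

Normalizing first keeps the later partitioning logic free of special cases for the single-type and multi-type forms.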

Source from the content-addressed store, hash-verified

def unbatch(g, node_split=None, edge_split=None):
    """Revert the batch operation by splitting the given graph into a list of small ones.

    This is the reverse operation of :func:`dgl.batch`. If the ``node_split``
    or the ``edge_split`` is not given, it calls :func:`DGLGraph.batch_num_nodes`
    and :func:`DGLGraph.batch_num_edges` of the input graph to get the information.

    If the ``node_split`` or the ``edge_split`` arguments are given,
    it will partition the graph according to the given segments. One must ensure
    that the partition is valid -- edges of the i-th graph only connect nodes
    belonging to the i-th graph. Otherwise, DGL will throw an error.

    The function supports heterograph input, in which case the two split
    section arguments shall be of dictionary type -- similar to the
    :func:`DGLGraph.batch_num_nodes`
    and :func:`DGLGraph.batch_num_edges` attributes of a heterograph.

    Parameters
    ----------
    g : DGLGraph
        Input graph to unbatch.
    node_split : Tensor, dict[str, Tensor], optional
        Number of nodes of each result graph.
    edge_split : Tensor, dict[str, Tensor], optional
        Number of edges of each result graph.

    Returns
    -------
    list[DGLGraph]
        Unbatched list of graphs.

    Examples
    --------

    Unbatch a batched graph

    >>> import dgl
    >>> import torch as th
    >>> # 4 nodes, 3 edges
    >>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
    >>> # 3 nodes, 4 edges
    >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
    >>> # add features
    >>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)
    >>> g1.edata['w'] = th.ones(g1.num_edges(), 2)
    >>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)
    >>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)
    >>> bg = dgl.batch([g1, g2])
    >>> f1, f2 = dgl.unbatch(bg)
    >>> f1
    Graph(num_nodes=4, num_edges=3,
          ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)}
          edata_schemes={'w' : Scheme(shape=(2,), dtype=torch.float32)})
    >>> f2
    Graph(num_nodes=3, num_edges=4,
          ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)}
          edata_schemes={'w' : Scheme(shape=(2,), dtype=torch.float32)})
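The validity rule in the docstring (edges of the i-th graph may only connect nodes belonging to the i-th graph) can be made concrete with a small pure-Python sketch. It is not DGL's implementation: `partition_edges` is a hypothetical helper, plain lists replace tensors, `ValueError` stands in for DGL's own error type, and it assumes that the batched graph stores the nodes and edges of each component contiguously and in order.

```python
def partition_edges(num_nodes, src, dst, node_split, edge_split):
    """Split a batched edge list into per-graph edge lists with local node IDs.

    Illustrative only. Returns a list of (num_nodes, src, dst) triples.
    """
    if len(node_split) != len(edge_split):
        raise ValueError("node_split and edge_split must have the same length")
    if sum(node_split) != num_nodes or sum(edge_split) != len(src):
        raise ValueError("split sizes must add up to the graph's totals")

    parts = []
    node_start = edge_start = 0
    for n_nodes, n_edges in zip(node_split, edge_split):
        node_end = node_start + n_nodes
        edge_end = edge_start + n_edges
        local_src, local_dst = [], []
        for u, v in zip(src[edge_start:edge_end], dst[edge_start:edge_end]):
            # Both endpoints must fall inside this segment's node range.
            if not (node_start <= u < node_end and node_start <= v < node_end):
                raise ValueError(f"edge ({u}, {v}) crosses a segment boundary")
            # Shift IDs so each result graph numbers its nodes from zero.
            local_src.append(u - node_start)
            local_dst.append(v - node_start)
        parts.append((n_nodes, local_src, local_dst))
        node_start, edge_start = node_end, edge_end
    return parts


# Edge list of the batched g1 and g2 from the docstring example:
# g2's node IDs are shifted by 4, the number of nodes in g1.
src = [0, 1, 2, 4, 4, 4, 5]
dst = [1, 2, 3, 4, 5, 6, 4]
parts = partition_edges(7, src, dst, node_split=[4, 3], edge_split=[3, 4])
```

With the split `[4, 3]` and `[3, 4]` the sketch recovers the two original edge lists. Passing `node_split=[3, 4]` instead raises an error, because the edge `(2, 3)` would then connect nodes in two different segments.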

Callers 1

decode · Method · 0.90

Calls 8

DGLError · Class · 0.85
batch_num_nodes · Method · 0.80
batch_num_edges · Method · 0.80
asnumpy · Method · 0.80
keys · Method · 0.45
values · Method · 0.45
items · Method · 0.45
edges · Method · 0.45

Tested by

no test coverage detected