hub / github.com/dmlc/dgl / batch_num_nodes

Method batch_num_nodes

python/dgl/heterograph.py:1461–1524

Return the number of nodes for each graph in the batch with the specified node type.

Parameters
----------
ntype : str, optional
    The node type for query. If the graph has multiple node types, one must specify the argument. Otherwise, it can be omitted.

(self, ntype=None)


        return len(self.batch_num_nodes(self.ntypes[0]))

    def batch_num_nodes(self, ntype=None):
        """Return the number of nodes for each graph in the batch with the specified node type.

        Parameters
        ----------
        ntype : str, optional
            The node type for query. If the graph has multiple node types, one must
            specify the argument. Otherwise, it can be omitted. If the graph is not a batched
            one, it will return a list of length 1 that holds the number of nodes in the graph.

        Returns
        -------
        Tensor
            The number of nodes with the specified type for each graph in the batch. The i-th
            element of it is the number of nodes with the specified type for the i-th graph.

        Examples
        --------

        The following example uses PyTorch backend.

        >>> import dgl
        >>> import torch

        Query for homogeneous graphs.

        >>> g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
        >>> g1.batch_num_nodes()
        tensor([4])
        >>> g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))
        >>> bg = dgl.batch([g1, g2])
        >>> bg.batch_num_nodes()
        tensor([4, 3])

        Query for heterogeneous graphs.

        >>> hg1 = dgl.heterograph({
        ...     ('user', 'plays', 'game') : (torch.tensor([0, 1]), torch.tensor([0, 0]))})
        >>> hg2 = dgl.heterograph({
        ...     ('user', 'plays', 'game') : (torch.tensor([0, 0]), torch.tensor([1, 0]))})
        >>> bg = dgl.batch([hg1, hg2])
        >>> bg.batch_num_nodes('user')
        tensor([2, 1])
        """
        if ntype is not None and ntype not in self.ntypes:
            raise DGLError(
                "Expect ntype in {}, got {}".format(self.ntypes, ntype)
            )

        if self._batch_num_nodes is None:
            self._batch_num_nodes = {}
            for ty in self.ntypes:
                bnn = F.copy_to(
                    F.tensor([self.num_nodes(ty)], self.idtype), self.device
                )
                self._batch_num_nodes[ty] = bnn
        if ntype is None:
            if len(self.ntypes) != 1:
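
The control flow in this listing (validate the node type, lazily treat an unbatched graph as a batch of one, then resolve a missing `ntype`) can be sketched in plain Python without DGL installed. This is only an illustration: `TinyGraph` is a hypothetical stand-in, plain lists replace tensors, `ValueError` replaces `DGLError`, and the error raised for an ambiguous query is an assumption based on the docstring, since the listing stops before that branch.

```python
class TinyGraph:
    """Hypothetical stand-in for a graph object; not part of DGL."""

    def __init__(self, num_nodes_by_type):
        # Maps node type name -> node count, e.g. {"user": 2, "game": 1}.
        self._num_nodes = dict(num_nodes_by_type)
        self.ntypes = list(self._num_nodes)
        # None means that no batch information has been recorded yet.
        self._batch_num_nodes = None

    def num_nodes(self, ntype):
        return self._num_nodes[ntype]

    def batch_num_nodes(self, ntype=None):
        # Reject unknown node types up front.
        if ntype is not None and ntype not in self.ntypes:
            raise ValueError(f"Expect ntype in {self.ntypes}, got {ntype}")

        # Lazily treat an unbatched graph as a batch of size one.
        if self._batch_num_nodes is None:
            self._batch_num_nodes = {
                ty: [self.num_nodes(ty)] for ty in self.ntypes
            }

        if ntype is None:
            # Assumption: an ambiguous query is an error, as the docstring
            # implies; the original listing is cut off before this branch.
            if len(self.ntypes) != 1:
                raise ValueError(
                    "ntype must be given when there are several node types"
                )
            ntype = self.ntypes[0]
        return self._batch_num_nodes[ntype]


homogeneous = TinyGraph({"_N": 4})
print(homogeneous.batch_num_nodes())        # [4]

heterogeneous = TinyGraph({"user": 2, "game": 1})
print(heterogeneous.batch_num_nodes("user"))  # [2]
```

Caching the per-type counts on first access means an unbatched graph and a batched one can be queried through the same code path.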

Callers 15

batch_size · Method · 0.95
readout_nodes · Function · 0.80
softmax_nodes · Function · 0.80
broadcast_nodes · Function · 0.80
batch · Function · 0.80
unbatch · Function · 0.80
slice_batch · Function · 0.80
forward · Method · 0.80
forward · Method · 0.80
forward · Method · 0.80
test_topk · Function · 0.80
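
Callers such as `readout_nodes` and `unbatch` need the per-graph node counts to tell where one graph's nodes end and the next begin in the batched graph. A minimal sketch of that idea follows, using plain Python lists instead of tensors. The helpers `split_by_batch` and `sum_readout` are hypothetical names for illustration and do not reproduce DGL's implementations.

```python
from itertools import accumulate


def split_by_batch(node_values, batch_num_nodes):
    """Split a flat per-node list into one list per graph."""
    if sum(batch_num_nodes) != len(node_values):
        raise ValueError("counts must add up to the number of nodes")
    # Running totals give the segment boundaries, e.g. [0, 4, 7] for [4, 3].
    offsets = [0, *accumulate(batch_num_nodes)]
    return [node_values[lo:hi] for lo, hi in zip(offsets, offsets[1:])]


def sum_readout(node_values, batch_num_nodes):
    """Sum node values separately for each graph in the batch."""
    return [sum(seg) for seg in split_by_batch(node_values, batch_num_nodes)]


# Counts [4, 3] match the batched homogeneous example in the docstring.
values = [1, 2, 3, 4, 5, 6, 7]
print(split_by_batch(values, [4, 3]))  # [[1, 2, 3, 4], [5, 6, 7]]
print(sum_readout(values, [4, 3]))     # [10, 18]
```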

Calls 4

num_nodes · Method · 0.95
DGLError · Class · 0.85
format · Method · 0.80
copy_to · Method · 0.45

Tested by 14

test_topk · Function · 0.64
test_to_device · Function · 0.64
test_pin_memory_ · Function · 0.64
test_topology · Function · 0.64
test_batching_batched · Function · 0.64
test_empty_relation · Function · 0.64
test_batch_unbatch · Function · 0.64
test_batch_unbatch1 · Function · 0.64
test_set_batch_info · Function · 0.64
test_remove_edges · Function · 0.64
test_remove_nodes · Function · 0.64