```python
#################################################################

def filter_nodes(self, predicate, nodes=ALL, ntype=None):
    """Return the IDs of the nodes with the given node type that satisfy
    the given predicate.

    Parameters
    ----------
    predicate : callable
        A function of signature ``func(nodes) -> Tensor``.
        ``nodes`` is a :class:`dgl.NodeBatch` object.
        Its output tensor should be a 1D boolean tensor with
        each element indicating whether the corresponding node in
        the batch satisfies the predicate.
    nodes : node ID(s), optional
        The node(s) for query. The allowed formats are:

        - Tensor: A 1D tensor that contains the node(s) for query, whose data type
          and device should be the same as the :py:attr:`idtype` and device of the graph.
        - iterable[int]: Similar to the tensor, but stores node IDs in a sequence
          (e.g. list, tuple, numpy.ndarray).

        By default, it considers all nodes.
    ntype : str, optional
        The node type for query. If the graph has multiple node types, one must
        specify the argument. Otherwise, it can be omitted.

    Returns
    -------
    Tensor
        A 1D tensor that contains the ID(s) of the node(s) that satisfy the predicate.

    Examples
    --------

    The following example uses the PyTorch backend.

    >>> import dgl
    >>> import torch

    Define a predicate function.

    >>> def nodes_with_feature_one(nodes):
    ...     # Whether a node has feature 1
    ...     return (nodes.data['h'] == 1.).squeeze(1)

    Filter nodes for a homogeneous graph.

    >>> g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
    >>> g.ndata['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
    >>> print(g.filter_nodes(nodes_with_feature_one))
    tensor([1, 2])

    Filter on nodes with IDs 0 and 1.

    >>> print(g.filter_nodes(nodes_with_feature_one, nodes=torch.tensor([0, 1])))
    tensor([1])

    Filter nodes for a heterogeneous graph.

```
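The docstring describes a contract: the predicate receives a batch holding only the queried nodes and returns one boolean per node in that batch, and `filter_nodes` keeps the IDs whose entry is `True`. The following is a minimal pure-Python sketch of that contract under stated assumptions, not DGL's implementation. The function `filter_nodes_sketch`, the plain-dict `NodeData` stand-in for `dgl.NodeBatch`, and the invalid-ID check are all hypothetical; node types (`ntype`) and tensors are left out so the sketch runs without DGL or PyTorch.

```python
from typing import Callable, Dict, List, Optional, Sequence

# Hypothetical stand-in for dgl.NodeBatch data: maps a feature name to
# one value per node (plain lists instead of tensors).
NodeData = Dict[str, List[float]]


def filter_nodes_sketch(
    num_nodes: int,
    ndata: NodeData,
    predicate: Callable[[NodeData], List[bool]],
    nodes: Optional[Sequence[int]] = None,
) -> List[int]:
    """Hypothetical sketch: return the queried IDs whose mask entry is True."""
    # nodes=None plays the role of nodes=ALL: consider every node.
    ids = list(range(num_nodes)) if nodes is None else list(nodes)
    # Assumption of this sketch: reject IDs outside the graph.
    for v in ids:
        if not 0 <= v < num_nodes:
            raise ValueError(f"invalid node ID: {v}")
    # Build the batch for the queried nodes only, in query order.
    batch = {key: [values[v] for v in ids] for key, values in ndata.items()}
    mask = predicate(batch)
    # The predicate must return exactly one boolean per node in the batch.
    if len(mask) != len(ids):
        raise ValueError("predicate must return one boolean per node")
    # Mask positions line up with the queried IDs, not with global IDs.
    return [v for v, keep in zip(ids, mask) if keep]


def nodes_with_feature_one(batch: NodeData) -> List[bool]:
    # Same idea as the docstring predicate, on a flattened feature list.
    return [h == 1.0 for h in batch["h"]]


ndata = {"h": [0.0, 1.0, 1.0, 0.0]}
print(filter_nodes_sketch(4, ndata, nodes_with_feature_one))  # [1, 2]
print(filter_nodes_sketch(4, ndata, nodes_with_feature_one, nodes=[0, 1]))  # [1]
```

Because the mask is computed over the batch of queried nodes, its positions correspond to the IDs passed in `nodes`, which is why restricting the query to IDs 0 and 1 yields only node 1, matching the docstring example.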