github.com/dmlc/dgl

Method filter_nodes

python/dgl/heterograph.py:5430–5512

Return the IDs of the nodes with the given node type that satisfy the given predicate.

(self, predicate, nodes=ALL, ntype=None)


    #################################################################

    def filter_nodes(self, predicate, nodes=ALL, ntype=None):
        """Return the IDs of the nodes with the given node type that satisfy
        the given predicate.

        Parameters
        ----------
        predicate : callable
            A function of signature ``func(nodes) -> Tensor``.
            ``nodes`` are :class:`dgl.NodeBatch` objects.
            Its output tensor should be a 1D boolean tensor with
            each element indicating whether the corresponding node in
            the batch satisfies the predicate.
        nodes : node ID(s), optional
            The node(s) for query. The allowed formats are:

            - Tensor: A 1D tensor that contains the node(s) for query, whose data type
              and device should be the same as the :py:attr:`idtype` and device of the graph.
            - iterable[int] : Similar to the tensor, but stores node IDs in a sequence
              (e.g. list, tuple, numpy.ndarray).

            By default, it considers all nodes.
        ntype : str, optional
            The node type for query. If the graph has multiple node types, one must
            specify the argument. Otherwise, it can be omitted.

        Returns
        -------
        Tensor
            A 1D tensor that contains the ID(s) of the node(s) that satisfy the predicate.

        Examples
        --------

        The following example uses PyTorch backend.

        >>> import dgl
        >>> import torch

        Define a predicate function.

        >>> def nodes_with_feature_one(nodes):
        ...     # Whether a node has feature 1
        ...     return (nodes.data['h'] == 1.).squeeze(1)

        Filter nodes for a homogeneous graph.

        >>> g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
        >>> g.ndata['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
        >>> print(g.filter_nodes(nodes_with_feature_one))
        tensor([1, 2])

        Filter on nodes with IDs 0 and 1

        >>> print(g.filter_nodes(nodes_with_feature_one, nodes=torch.tensor([0, 1])))
        tensor([1])

        Filter nodes for a heterogeneous graph.

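To make the docstring's contract concrete without depending on DGL or PyTorch, here is a minimal pure-Python sketch of what the method computes: the predicate receives a batch of the queried nodes, returns one boolean per node, and the result is the IDs whose entry is True. This is not DGL's implementation. `NodeBatchSketch` and `filter_node_ids` are hypothetical names, features are plain lists of scalars rather than `(N, 1)` tensors (so the example's `.squeeze(1)` is unnecessary here), and `nodes=None` stands in for the `ALL` default. It reproduces the two homogeneous-graph results shown in the examples.

```python
from typing import Callable, Dict, List, Optional, Sequence


class NodeBatchSketch:
    """Hypothetical stand-in for dgl.NodeBatch: only exposes a ``data`` mapping."""

    def __init__(self, data: Dict[str, List[float]]):
        self.data = data


def filter_node_ids(
    features: Dict[str, List[float]],
    predicate: Callable[[NodeBatchSketch], Sequence[bool]],
    nodes: Optional[Sequence[int]] = None,
) -> List[int]:
    """Hypothetical helper: return the queried IDs whose predicate mask is True."""
    num_nodes = len(next(iter(features.values())))
    # nodes=None plays the role of DGL's ALL default: query every node.
    ids = list(range(num_nodes)) if nodes is None else list(nodes)
    # Gather the feature rows of the queried nodes into one batch.
    batch = NodeBatchSketch({name: [col[i] for i in ids] for name, col in features.items()})
    mask = list(predicate(batch))
    # The predicate must return a 1D mask with one boolean per queried node.
    if len(mask) != len(ids):
        raise ValueError("predicate must return one boolean per queried node")
    # Keep the IDs whose mask entry is True, in query order.
    return [i for i, keep in zip(ids, mask) if keep]


def nodes_with_feature_one(nodes: NodeBatchSketch) -> List[bool]:
    # Same test as the docstring's predicate, applied to scalar features.
    return [h == 1.0 for h in nodes.data["h"]]


features = {"h": [0.0, 1.0, 1.0, 0.0]}
print(filter_node_ids(features, nodes_with_feature_one))          # [1, 2]
print(filter_node_ids(features, nodes_with_feature_one, [0, 1]))  # [1]
```

Restricting `nodes` shrinks the batch the predicate sees, so the predicate only has to produce a mask as long as the query, which is why node 2 drops out of the second result even though its feature is 1.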

Callers (4)

test_types_in_function  ·  Function  ·  0.80
test_graph_filter  ·  Function  ·  0.80
decode  ·  Method  ·  0.80
forward  ·  Method  ·  0.80

Calls (7)

nodes  ·  Method  ·  0.95
has_nodes  ·  Method  ·  0.95
local_scope  ·  Method  ·  0.95
apply_nodes  ·  Method  ·  0.95
is_all  ·  Function  ·  0.85
DGLError  ·  Class  ·  0.85
predicate  ·  Function  ·  0.85
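The Calls list suggests the overall flow: a check for the `ALL` default (`is_all`), validation of explicit IDs (`has_nodes`, with `DGLError` for failures), and evaluation of the predicate over the queried nodes (`apply_nodes` inside `local_scope`). The sketch below is a hypothetical pure-Python analogue of that flow combined with the documented `ntype` rule; it is inferred from the call list and docstring only, not copied from DGL. `ToyHeteroGraph`, the `"user"`/`"game"` node types, passing the predicate a plain feature dict instead of a `NodeBatch`, and raising `ValueError` instead of `DGLError` are all assumptions.

```python
from typing import Callable, Dict, List, Optional, Sequence

Features = Dict[str, List[float]]


class ToyHeteroGraph:
    """Hypothetical toy graph: node features per node type, edges omitted."""

    def __init__(self, ndata: Dict[str, Features]):
        # ndata maps node type -> feature name -> one value per node.
        self.ndata = ndata

    def num_nodes(self, ntype: str) -> int:
        return len(next(iter(self.ndata[ntype].values())))

    def filter_nodes(
        self,
        predicate: Callable[[Features], Sequence[bool]],
        nodes: Optional[Sequence[int]] = None,
        ntype: Optional[str] = None,
    ) -> List[int]:
        # Documented rule: ntype may be omitted only if there is one node type.
        if ntype is None:
            if len(self.ndata) != 1:
                raise ValueError("ntype must be specified for a graph with multiple node types")
            ntype = next(iter(self.ndata))
        n = self.num_nodes(ntype)
        # nodes=None stands in for the ALL default (the is_all check).
        ids = list(range(n)) if nodes is None else list(nodes)
        # Assumed validation, analogous to has_nodes + DGLError: reject unknown IDs.
        invalid = [i for i in ids if not 0 <= i < n]
        if invalid:
            raise ValueError(f"invalid {ntype!r} node IDs: {invalid}")
        # Evaluate the predicate on the queried nodes' features only.
        batch = {name: [col[i] for i in ids] for name, col in self.ndata[ntype].items()}
        mask = list(predicate(batch))
        return [i for i, keep in zip(ids, mask) if keep]


def has_feature_one(data: Features) -> List[bool]:
    # Hypothetical predicate: True where feature 'h' equals 1.
    return [h == 1.0 for h in data["h"]]


g = ToyHeteroGraph({"user": {"h": [1.0, 0.0, 1.0]}, "game": {"h": [0.0, 1.0]}})
print(g.filter_nodes(has_feature_one, ntype="user"))                # [0, 2]
print(g.filter_nodes(has_feature_one, nodes=[0, 1], ntype="game"))  # [1]
```

In this sketch, calling `g.filter_nodes(has_feature_one)` without `ntype` raises `ValueError` because the toy graph has two node types, mirroring the docstring's requirement that `ntype` be given in that case.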

Tested by (2)

test_types_in_function  ·  Function  ·  0.64
test_graph_filter  ·  Function  ·  0.64