github.com/dmlc/dgl / filter_edges

Method filter_edges

python/dgl/heterograph.py:5514–5629

Return the IDs of the edges with the given edge type that satisfy the given predicate.
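Per the docstring's Parameters and Returns sections, the predicate yields one boolean per edge in the batch, and the method returns the IDs of the edges flagged true. The selection step on its own can be sketched in pure Python; `select_edge_ids` is a hypothetical helper (not a DGL function), and plain lists stand in for tensors:

```python
from typing import List, Sequence


def select_edge_ids(edge_ids: Sequence[int], mask: Sequence[bool]) -> List[int]:
    """Keep the edge IDs whose mask entry is True (hypothetical helper).

    Mirrors the documented contract: the predicate output must be 1D,
    with exactly one boolean per edge in the batch.
    """
    if len(edge_ids) != len(mask):
        raise ValueError(
            f"mask has {len(mask)} entries but the batch has {len(edge_ids)} edges"
        )
    # Pair each edge ID with its flag and keep only the flagged ones.
    return [eid for eid, keep in zip(edge_ids, mask) if keep]


# Edges 0, 1, 2 where only edges 1 and 2 satisfy the predicate.
print(select_edge_ids([0, 1, 2], [False, True, True]))  # [1, 2]
```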

(self, predicate, edges=ALL, etype=None)
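The `edges` parameter defaults to `ALL` and otherwise accepts the formats listed in the docstring: a single ID, an iterable of IDs, or a pair of endpoint sequences. The sketch below normalizes those formats into a list of edge IDs for a toy homogeneous graph stored as parallel `src`/`dst` lists. The `ALL` string value, the helper name `resolve_edges`, and the choice of the first matching edge when parallel edges exist are assumptions for illustration; tensor inputs, which DGL also accepts, are left out.

```python
from typing import Dict, List, Sequence, Tuple, Union

# Assumed sentinel meaning "every edge"; DGL exposes its own ``ALL`` constant.
ALL = "__ALL__"

EdgesArg = Union[str, int, Sequence[int], Tuple[Sequence[int], Sequence[int]]]


def resolve_edges(src: Sequence[int], dst: Sequence[int], edges: EdgesArg = ALL) -> List[int]:
    """Normalize a filter_edges-style ``edges`` argument into edge IDs (sketch).

    Edge ``i`` of the toy graph goes from ``src[i]`` to ``dst[i]``.
    """
    num_edges = len(src)
    if isinstance(edges, str):
        if edges != ALL:
            raise ValueError(f"unsupported edges argument {edges!r}")
        return list(range(num_edges))  # default: consider every edge
    if isinstance(edges, int):
        ids = [edges]  # a single edge ID
    elif isinstance(edges, tuple) and len(edges) == 2 and not isinstance(edges[0], int):
        # Endpoint-pair format: the i-th elements of the two sequences name one edge.
        first_eid: Dict[Tuple[int, int], int] = {}
        for eid, uv in enumerate(zip(src, dst)):
            first_eid.setdefault(uv, eid)  # assumption: keep the first parallel edge
        ids = []
        for uv in zip(*edges):
            if uv not in first_eid:
                raise ValueError(f"no edge from {uv[0]} to {uv[1]}")
            ids.append(first_eid[uv])
    else:
        ids = list(edges)  # an iterable of edge IDs
    for eid in ids:
        if not 0 <= eid < num_edges:
            raise ValueError(f"edge ID {eid} is out of range")
    return ids


print(resolve_edges([0, 1, 2], [1, 2, 3], ([1, 2], [2, 3])))  # [1, 2]
```

A tuple of two plain integers such as `(0, 2)` is treated as a list of IDs here, not as a single endpoint pair, because the pair format expects two sequences.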

def filter_edges(self, predicate, edges=ALL, etype=None):
    """Return the IDs of the edges with the given edge type that satisfy
    the given predicate.

    Parameters
    ----------
    predicate : callable
        A function of signature ``func(edges) -> Tensor``.
        ``edges`` is a :class:`dgl.EdgeBatch` object.
        The function's output should be a 1D boolean tensor with one
        element per edge in the batch, indicating whether that edge
        satisfies the predicate.
    edges : edges
        The edges to test against the predicate. The allowed input formats are:

        * ``int``: A single edge ID.
        * Int Tensor: Each element is an edge ID. The tensor must have the same device type
          and ID data type as the graph's.
        * iterable[int]: Each element is an edge ID.
        * (Tensor, Tensor): The node-tensors format where the i-th elements
          of the two tensors specify an edge.
        * (iterable[int], iterable[int]): Similar to the node-tensors format but
          stores edge endpoints in Python iterables.

        By default, it considers all the edges.
    etype : str or (str, str, str), optional
        The type name of the edges. The allowed type name formats are:

        * ``(str, str, str)`` for source node type, edge type and destination node type.
        * ``str``: an edge type name, if the name uniquely identifies one
          canonical edge type triplet in the graph.

        Can be omitted if the graph has only one type of edges.

    Returns
    -------
    Tensor
        A 1D tensor that contains the ID(s) of the edge(s) that satisfy the predicate.

    Examples
    --------

    The following example uses the PyTorch backend.

    >>> import dgl
    >>> import torch

    Define a predicate function.

    >>> def edges_with_feature_one(edges):
    ...     # Whether an edge has feature 1
    ...     return (edges.data['h'] == 1.).squeeze(1)

    Filter edges for a homogeneous graph.

    >>> g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
    >>> g.edata['h'] = torch.tensor([[0.], [1.], [1.]])
    >>> print(g.filter_edges(edges_with_feature_one))
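The docstring example can be traced without DGL or PyTorch. In this sketch, a list of one-element rows stands in for `g.edata['h']` (shape `(E, 1)`), taking `row[0]` plays the role of `.squeeze(1)`, and `ToyEdgeBatch` and `toy_filter_edges` are hypothetical stand-ins, not DGL's `EdgeBatch` or its implementation. Edges 1 and 2 carry feature 1, so those are the IDs kept; restricting the batch to edges 0 and 1 leaves only edge 1.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Sequence


@dataclass
class ToyEdgeBatch:
    """Hypothetical stand-in for dgl.EdgeBatch: edge IDs plus their features."""
    eids: List[int]
    data: Dict[str, List[List[float]]] = field(default_factory=dict)


def edges_with_feature_one(edges: ToyEdgeBatch) -> List[bool]:
    # Compare each (1,)-shaped row with 1.0 and flatten, like `(h == 1.).squeeze(1)`.
    return [row[0] == 1.0 for row in edges.data["h"]]


def toy_filter_edges(
    h: Sequence[List[float]],
    predicate: Callable[[ToyEdgeBatch], List[bool]],
    eids: Sequence[int],
) -> List[int]:
    # Build a batch holding only the requested edges, then keep the IDs the mask marks True.
    batch = ToyEdgeBatch(eids=list(eids), data={"h": [h[e] for e in eids]})
    mask = predicate(batch)
    return [e for e, keep in zip(batch.eids, mask) if keep]


# Same graph as the docstring: edges 0->1, 1->2, 2->3 with h = [[0.], [1.], [1.]].
h = [[0.0], [1.0], [1.0]]
print(toy_filter_edges(h, edges_with_feature_one, range(3)))  # [1, 2]
print(toy_filter_edges(h, edges_with_feature_one, [0, 1]))    # [1]
```

With the real PyTorch backend the method returns a 1D tensor of edge IDs, as the Returns section states, rather than a Python list.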

Callers (2)

test_types_in_function · Function · 0.80
test_graph_filter · Function · 0.80

Calls (10)

to_canonical_etype · Method · 0.95
has_nodes · Method · 0.95
num_edges · Method · 0.95
local_scope · Method · 0.95
apply_edges · Method · 0.95
edge_ids · Method · 0.95
is_all · Function · 0.85
DGLError · Class · 0.85
predicate · Function · 0.85
format · Method · 0.80
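The Calls list includes `to_canonical_etype`, which fits the `etype` rules in the docstring: a `(src_type, edge_type, dst_type)` triplet is used as given, a bare edge type name must match exactly one triplet, and `None` is only allowed when the graph has a single edge type. A pure-Python sketch of those rules over a list of canonical edge types; `resolve_etype`, its error messages, and the example type names are illustrative assumptions, not DGL's code:

```python
from typing import List, Optional, Tuple, Union

CanonicalEType = Tuple[str, str, str]


def resolve_etype(
    canonical_etypes: List[CanonicalEType],
    etype: Optional[Union[str, CanonicalEType]] = None,
) -> CanonicalEType:
    """Map an ``etype`` argument to one canonical triplet (sketch)."""
    if etype is None:
        # Omitting etype only works when there is exactly one edge type.
        if len(canonical_etypes) != 1:
            raise ValueError("etype must be given when the graph has several edge types")
        return canonical_etypes[0]
    if isinstance(etype, tuple):
        if etype not in canonical_etypes:
            raise ValueError(f"unknown edge type {etype}")
        return etype
    # A bare name must identify exactly one triplet.
    matches = [c for c in canonical_etypes if c[1] == etype]
    if len(matches) != 1:
        raise ValueError(f"edge type name {etype!r} matches {len(matches)} triplets")
    return matches[0]


etypes = [
    ("user", "follows", "user"),
    ("user", "rates", "game"),
    ("user", "rates", "movie"),
]
print(resolve_etype(etypes, "follows"))  # ('user', 'follows', 'user')
```

Here `"rates"` is ambiguous and raises an error, so a caller would have to pass the full triplet, for example `("user", "rates", "movie")`.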

Tested by (2)

test_types_in_function · Function · 0.64
test_graph_filter · Function · 0.64