                return F.boolean_mask(v, F.gather_row(mask, v))

    def filter_edges(self, predicate, edges=ALL, etype=None):
        """Return the IDs of the edges with the given edge type that satisfy
        the given predicate.

        Parameters
        ----------
        predicate : callable
            A function of signature ``func(edges) -> Tensor``.
            ``edges`` is a :class:`dgl.EdgeBatch` object.
            The output should be a 1D boolean tensor in which
            each element indicates whether the corresponding edge in
            the batch satisfies the predicate.
        edges : edges
            The edges on which to evaluate the predicate. The allowed input formats are:

            * ``int``: A single edge ID.
            * Int Tensor: Each element is an edge ID. The tensor must have the same device type
              and ID data type as the graph's.
            * iterable[int]: Each element is an edge ID.
            * (Tensor, Tensor): The node-tensors format where the i-th elements
              of the two tensors specify an edge.
            * (iterable[int], iterable[int]): Similar to the node-tensors format but
              stores edge endpoints in Python iterables.

            By default, all edges are considered.
        etype : str or (str, str, str), optional
            The type name of the edges. The allowed type name formats are:

            * ``(str, str, str)`` for source node type, edge type and destination node type.
            * or a single ``str`` edge type name, if that name uniquely identifies a
              triplet in the graph.

            Can be omitted if the graph has only one type of edges.

        Returns
        -------
        Tensor
            A 1D tensor that contains the ID(s) of the edge(s) that satisfy the predicate.

        Examples
        --------

        The following example uses the PyTorch backend.

        >>> import dgl
        >>> import torch

        Define a predicate function.

        >>> def edges_with_feature_one(edges):
        ...     # Whether an edge has feature 1
        ...     return (edges.data['h'] == 1.).squeeze(1)

        Filter edges for a homogeneous graph.

        >>> g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
        >>> g.edata['h'] = torch.tensor([[0.], [1.], [1.]])
        >>> print(g.filter_edges(edges_with_feature_one))