Class that finds the edges whose parent-graph IDs appear in `exclude_eids`. The edge IDs can be either CPU or GPU tensors.
| 24 | |
| 25 | |
| 26 | class EidExcluder(object): |
| 27 | """Class that finds the edges whose IDs in parent graph appeared in exclude_eids. |
| 28 | |
| 29 | The edge IDs can be both CPU and GPU tensors. |
| 30 | """ |
| 31 | |
def __init__(self, exclude_eids):
    """Choose the edge-lookup strategy from the device holding ``exclude_eids``.

    Parameters
    ----------
    exclude_eids : tensor or Mapping
        Parent-graph edge IDs to exclude. For a Mapping, only the device of
        its first value is inspected.
    """
    # Private sentinel so an empty mapping is distinguishable from any
    # real first value.
    missing = object()
    if isinstance(exclude_eids, Mapping):
        first_eids = next(iter(exclude_eids.values()), missing)
        # An empty mapping leaves ``device`` as None.
        device = None if first_eids is missing else F.context(first_eids)
    else:
        device = F.context(exclude_eids)

    # Exactly one of these is populated below, depending on the device.
    self._exclude_eids = None
    self._filter = None

    if device == F.cpu():
        # TODO(nv-dlasalle): Once Filter is implemented for the CPU, we
        # should just use that regardless of the device.
        if exclude_eids is not None:
            # CPU path: keep NumPy copies of the IDs (via
            # F.zerocopy_to_numpy) for matching in _find_indices.
            self._exclude_eids = recursive_apply(
                exclude_eids, F.zerocopy_to_numpy
            )
    else:
        # Non-CPU path: wrap each ID tensor in a utils.Filter.
        self._filter = recursive_apply(exclude_eids, utils.Filter)
| 54 | |
def _find_indices(self, parent_eids):
    """Find the set of edge indices to remove.

    ``parent_eids`` is either a single ID container or a Mapping of them.
    Whichever strategy ``__init__`` prepared is used: a utils.Filter lookup,
    or a NumPy-based match against the stored exclusion IDs.
    """
    if self._exclude_eids is None:
        # Filter path: __init__ must have built a Filter instead.
        assert self._filter is not None
        return recursive_apply_pair(
            self._filter,
            parent_eids,
            lambda eid_filter, eids: eid_filter.find_included_indices(eids),
        )
    # CPU path: convert to NumPy and match against the stored IDs.
    parent_eids_np = recursive_apply(parent_eids, F.zerocopy_to_numpy)
    return _locate_eids_to_exclude(parent_eids_np, self._exclude_eids)
| 64 | |
| 65 | def __call__(self, frontier, weights=None): |
| 66 | parent_eids = frontier.edata[EID] |
| 67 | located_eids = self._find_indices(parent_eids) |
| 68 | |
| 69 | if not isinstance(located_eids, Mapping): |
| 70 | # (BarclayII) If frontier already has a EID field and located_eids is empty, |
| 71 | # the returned graph will keep EID intact. Otherwise, EID will change |
| 72 | # to the mapping from the new graph to the old frontier. |
| 73 | # So we need to test if located_eids is empty, and do the remapping ourselves. |
| 74 | if len(located_eids) > 0: |
| 75 | frontier = transforms.remove_edges( |
| 76 | frontier, located_eids, store_ids=True |
| 77 | ) |
| 78 | if ( |
| 79 | weights is not None |
| 80 | and weights[0].shape[0] == frontier.num_edges() |
| 81 | ): |
| 82 | weights[0] = F.gather_row(weights[0], frontier.edata[EID]) |
| 83 | frontier.edata[EID] = F.gather_row( |
no outgoing calls
no test coverage detected