MCPcopy
hub / github.com/dmlc/dgl / EidExcluder

Class EidExcluder

python/dgl/sampling/utils.py:26–105  ·  view source on GitHub ↗

Class that finds the edges whose IDs in the parent graph appear in exclude_eids. The edge IDs can be either CPU or GPU tensors.

Source from the content-addressed store, hash-verified

24
25
26class EidExcluder(object):
27 """Class that finds the edges whose IDs in parent graph appeared in exclude_eids.
28
29 The edge IDs can be both CPU and GPU tensors.
30 """
31
def __init__(self, exclude_eids):
    """Prepare the lookup structure used to find excluded edges.

    Parameters
    ----------
    exclude_eids : Tensor or Mapping of Tensor
        Parent-graph edge IDs to exclude. For a mapping, only the device
        of the first value is inspected, so all values are assumed to live
        on the same device (NOTE(review): not checked here -- confirm
        callers guarantee this).

    If the IDs are on the CPU, they are converted with
    ``F.zerocopy_to_numpy`` and stored in ``self._exclude_eids``.
    Otherwise they are wrapped in ``utils.Filter`` and stored in
    ``self._filter``. Exactly one of the two attributes ends up set in the
    normal case.
    """
    if isinstance(exclude_eids, Mapping):
        # Device of the first value; an empty mapping leaves the device
        # unknown (None), which then takes the Filter branch below.
        device = (
            F.context(next(iter(exclude_eids.values())))
            if exclude_eids
            else None
        )
    else:
        device = F.context(exclude_eids)

    self._exclude_eids = None
    self._filter = None

    if device == F.cpu():
        # TODO(nv-dlasalle): Once Filter is implemented for the CPU, we
        # should just use that regardless of the device.
        if exclude_eids is not None:
            self._exclude_eids = recursive_apply(
                exclude_eids, F.zerocopy_to_numpy
            )
    else:
        self._filter = recursive_apply(exclude_eids, utils.Filter)
def _find_indices(self, parent_eids):
    """Return the indices into ``parent_eids`` of the edges to remove.

    When CPU exclusion IDs were stored at construction, ``parent_eids`` is
    converted with ``F.zerocopy_to_numpy`` and matched with
    ``_locate_eids_to_exclude``. Otherwise each stored ``Filter`` is paired
    with the corresponding ``parent_eids`` entry and queried through
    ``find_included_indices``.
    """
    if self._exclude_eids is None:
        assert self._filter is not None

        def _included(eid_filter, eids):
            return eid_filter.find_included_indices(eids)

        return recursive_apply_pair(self._filter, parent_eids, _included)

    eids_as_numpy = recursive_apply(parent_eids, F.zerocopy_to_numpy)
    return _locate_eids_to_exclude(eids_as_numpy, self._exclude_eids)
64
65 def __call__(self, frontier, weights=None):
66 parent_eids = frontier.edata[EID]
67 located_eids = self._find_indices(parent_eids)
68
69 if not isinstance(located_eids, Mapping):
70 # (BarclayII) If frontier already has a EID field and located_eids is empty,
71 # the returned graph will keep EID intact. Otherwise, EID will change
72 # to the mapping from the new graph to the old frontier.
73 # So we need to test if located_eids is empty, and do the remapping ourselves.
74 if len(located_eids) > 0:
75 frontier = transforms.remove_edges(
76 frontier, located_eids, store_ids=True
77 )
78 if (
79 weights is not None
80 and weights[0].shape[0] == frontier.num_edges()
81 ):
82 weights[0] = F.gather_row(weights[0], frontier.edata[EID])
83 frontier.edata[EID] = F.gather_row(

Callers 5

sample_labors — Function · 0.85
sample_neighbors — Function · 0.85
sample_neighbors_fused — Function · 0.85
sample — Method · 0.85
sample — Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected