MCPcopy
hub / github.com/dmlc/dgl / _LazyIndex

Class _LazyIndex

python/dgl/frame.py:14–53  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

12
13
14class _LazyIndex(object):
15 def __init__(self, index):
16 if isinstance(index, list):
17 self._indices = index
18 else:
19 self._indices = [index]
20
21 def __len__(self):
22 return len(self._indices[-1])
23
24 def slice(self, index):
25 """Create a new _LazyIndex object sliced by the given index tensor."""
26 # if our indices are in the same context, lets just slice now and free
27 # memory, otherwise do nothing until we have to
28 if F.context(self._indices[-1]) == F.context(index):
29 return _LazyIndex(
30 self._indices[:-1] + [F.gather_row(self._indices[-1], index)]
31 )
32 return _LazyIndex(self._indices + [index])
33
34 def flatten(self):
35 """Evaluate the chain of indices, and return a single index tensor."""
36 flat_index = self._indices[0]
37 # here we actually need to resolve it
38 for index in self._indices[1:]:
39 if F.context(index) != F.context(flat_index):
40 index = F.copy_to(index, F.context(flat_index))
41 flat_index = F.gather_row(flat_index, index)
42 return flat_index
43
44 def record_stream(self, stream):
45 """Record stream for index.
46
47 Parameters
48 ----------
49 stream : torch.cuda.Stream.
50 """
51 for index in self._indices:
52 if F.context(index) != F.cpu():
53 index.record_stream(stream)
54
55
56class LazyFeature(object):

Callers 2

sliceMethod · 0.85
subcolumnMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected