| 12 | |
| 13 | |
class _LazyIndex(object):
    """A chain of row-index tensors whose composition may be deferred.

    The chain lives in ``self._indices``. Resolving it (see ``flatten``)
    starts from the first index and uses each subsequent index to gather
    rows from the result accumulated so far. ``slice`` composes eagerly when
    the new index shares a context with the newest one in the chain, and
    otherwise just appends it.
    """

    def __init__(self, index):
        """Wrap ``index`` as a chain.

        Parameters
        ----------
        index : list or index tensor
            A list is used directly as the chain (kept by reference, not
            copied); any other value becomes a one-element chain.
        """
        self._indices = index if isinstance(index, list) else [index]

    def __len__(self):
        """Return the length of the newest (last) index in the chain."""
        newest = self._indices[-1]
        return len(newest)

    def slice(self, index):
        """Create a new _LazyIndex object sliced by the given index tensor.

        When ``index`` is in the same context as the newest index, the two
        are combined immediately with ``F.gather_row`` so the intermediate
        index does not have to be retained. Otherwise ``index`` is appended
        to the chain and the gather is postponed until ``flatten``.
        """
        newest = self._indices[-1]
        same_context = F.context(newest) == F.context(index)
        if not same_context:
            # Different contexts: defer the work, just extend the chain.
            return _LazyIndex(self._indices + [index])
        composed = F.gather_row(newest, index)
        return _LazyIndex(self._indices[:-1] + [composed])

    def flatten(self):
        """Evaluate the chain of indices, and return a single index tensor.

        Any index whose context differs from the accumulated result is first
        moved there with ``F.copy_to`` before being used to gather rows.
        """
        result = self._indices[0]
        for step in self._indices[1:]:
            step_ctx = F.context(step)
            dest_ctx = F.context(result)
            if step_ctx != dest_ctx:
                step = F.copy_to(step, dest_ctx)
            result = F.gather_row(result, step)
        return result

    def record_stream(self, stream):
        """Record stream for index.

        Calls ``record_stream(stream)`` on every index in the chain whose
        context is not ``F.cpu()``; CPU-resident indices are skipped.

        Parameters
        ----------
        stream : torch.cuda.Stream.
        """
        off_cpu = (idx for idx in self._indices if F.context(idx) != F.cpu())
        for idx in off_cpu:
            idx.record_stream(stream)
| 55 | |
| 56 | class LazyFeature(object): |