Index class that can be easily converted to list/tensor.
| 29 | |
| 30 | |
| 31 | class Index(object): |
| 32 | """Index class that can be easily converted to list/tensor.""" |
| 33 | |
| 34 | def __init__(self, data, dtype="int64"): |
| 35 | assert dtype in ["int32", "int64"] |
| 36 | self.dtype = dtype |
| 37 | self._initialize_data(data) |
| 38 | |
| 39 | def _initialize_data(self, data): |
| 40 | self._pydata = None # a numpy type data |
| 41 | self._user_tensor_data = dict() # dictionary of user tensors |
| 42 | self._dgl_tensor_data = None # a dgl ndarray |
| 43 | self._slice_data = None # a slice type data |
| 44 | self._dispatch(data) |
| 45 | |
| 46 | def __iter__(self): |
| 47 | for i in self.tonumpy(): |
| 48 | yield int(i) |
| 49 | |
| 50 | def __len__(self): |
| 51 | if self._slice_data is not None: |
| 52 | slc = self._slice_data |
| 53 | return slc.stop - slc.start |
| 54 | elif self._pydata is not None: |
| 55 | return len(self._pydata) |
| 56 | elif len(self._user_tensor_data) > 0: |
| 57 | data = next(iter(self._user_tensor_data.values())) |
| 58 | return len(data) |
| 59 | else: |
| 60 | return len(self._dgl_tensor_data) |
| 61 | |
| 62 | def __getitem__(self, i): |
| 63 | return int(self.tonumpy()[i]) |
| 64 | |
| 65 | def _dispatch(self, data): |
| 66 | """Store data based on its type.""" |
| 67 | if F.is_tensor(data): |
| 68 | if F.dtype(data) != F.data_type_dict[self.dtype]: |
| 69 | raise InconsistentDtypeException( |
| 70 | "Index data specified as %s, but got: %s" |
| 71 | % (self.dtype, F.reverse_data_type_dict[F.dtype(data)]) |
| 72 | ) |
| 73 | if len(F.shape(data)) > 1: |
| 74 | raise InconsistentDtypeException( |
| 75 | "Index data must be 1D int32/int64 vector,\ |
| 76 | but got shape: %s" |
| 77 | % str(F.shape(data)) |
| 78 | ) |
| 79 | if len(F.shape(data)) == 0: |
| 80 | # a tensor of one int |
| 81 | self._dispatch(int(data)) |
| 82 | else: |
| 83 | self._user_tensor_data[F.context(data)] = data |
| 84 | elif isinstance(data, nd.NDArray): |
| 85 | if not (data.dtype == self.dtype and len(data.shape) == 1): |
| 86 | raise InconsistentDtypeException( |
| 87 | "Index data must be 1D %s vector, but got: %s" |
| 88 | % (self.dtype, data.dtype) |
no outgoing calls
no test coverage detected