Store data based on its type.
(self, data)
| 63 | return int(self.tonumpy()[i]) |
| 64 | |
def _dispatch(self, data):
    """Store ``data`` in the internal representation that matches its type.

    Dispatch rules, in the order they are checked:

    * Framework tensor (``F.is_tensor(data)``): its dtype must equal
      ``F.data_type_dict[self.dtype]`` and its rank must be 0 or 1.
      A rank-0 tensor is converted to a Python ``int`` and re-dispatched.
      A rank-1 tensor is cached in ``self._user_tensor_data``, keyed by
      ``F.context(data)``.
    * ``nd.NDArray``: must satisfy ``data.dtype == self.dtype`` and be 1D.
      It is stored in ``self._dgl_tensor_data``.
    * ``slice``: the step must be ``None`` or 1. It is stored as
      ``slice(start, stop)`` in ``self._slice_data``.
    * Anything else: converted with ``np.asarray(data, dtype=self.dtype)``.
      A 0-d result is expanded to length 1, and any other non-1D result
      is rejected. The array is stored in ``self._pydata``, and
      ``F.zerocopy_from_numpy(self._pydata)`` is cached in
      ``self._user_tensor_data[F.cpu()]``.

    Parameters
    ----------
    data : tensor, nd.NDArray, slice, int, or array-like
        The index data to store.

    Raises
    ------
    InconsistentDtypeException
        If a tensor or NDArray has the wrong dtype or rank.
    DGLError
        If generic data cannot be converted to a 1D ``self.dtype`` array.
    AssertionError
        If a slice has a step other than ``None`` or 1.
    """
    if F.is_tensor(data):
        if F.dtype(data) != F.data_type_dict[self.dtype]:
            raise InconsistentDtypeException(
                "Index data specified as %s, but got: %s"
                % (self.dtype, F.reverse_data_type_dict[F.dtype(data)])
            )
        if len(F.shape(data)) > 1:
            # Adjacent literals are joined at compile time. The previous
            # backslash continuation inside the string leaked the source
            # indentation into the message.
            raise InconsistentDtypeException(
                "Index data must be 1D int32/int64 vector, "
                "but got shape: %s" % str(F.shape(data))
            )
        if len(F.shape(data)) == 0:
            # A rank-0 tensor holds a single integer. Re-dispatch it as a
            # Python int so it goes through the generic array-like path.
            self._dispatch(int(data))
        else:
            self._user_tensor_data[F.context(data)] = data
    elif isinstance(data, nd.NDArray):
        if not (data.dtype == self.dtype and len(data.shape) == 1):
            raise InconsistentDtypeException(
                "Index data must be 1D %s vector, but got: %s"
                % (self.dtype, data.dtype)
            )
        self._dgl_tensor_data = data
    elif isinstance(data, slice):
        # Kept lazily in _slice_data. The original note says it is
        # materialized when `tonumpy` is called (tonumpy is not visible
        # here; confirm).
        # NOTE(review): `assert` is stripped under `python -O`. It is kept
        # so callers still observe AssertionError on a bad step.
        assert (
            data.step == 1 or data.step is None
        ), "step for slice type must be 1"
        self._slice_data = slice(data.start, data.stop)
    else:
        try:
            data = np.asarray(data, dtype=self.dtype)
        except Exception as err:  # pylint: disable=broad-except
            # Chain the original error so the conversion failure is visible.
            raise DGLError("Error index data: %s" % str(data)) from err
        if data.ndim == 0:  # scalar array
            data = np.expand_dims(data, 0)
        elif data.ndim != 1:
            # Report the dtype actually requested (may be int32, not int64).
            raise DGLError(
                "Index data must be 1D %s vector,"
                " but got: %s" % (self.dtype, str(data))
            )
        self._pydata = data
        self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(self._pydata)
| 113 | |
| 114 | def tonumpy(self): |
| 115 | """Convert to a numpy ndarray.""" |
No test coverage detected.