Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors. When the dataset is on the GPU, this significantly reduces the overhead.
| 189 | |
| 190 | |
class TensorizedDataset(torch.utils.data.IterableDataset):
    """Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.

    When the dataset is on the GPU, this significantly reduces the overhead.

    Parameters
    ----------
    indices : Tensor or Mapping
        The ids to iterate over. When a mapping is given, its keys are recorded
        in iteration order, the device is taken from its first value, and the
        values are combined into one id tensor by
        ``_get_id_tensor_from_mapping``. An empty mapping is rejected.
    batch_size : int
        Number of ids per minibatch.
    drop_last : bool
        Whether the final incomplete minibatch is dropped (this also decides
        whether ``__len__`` rounds down or up).
    shuffle : bool
        Forwarded to ``_TensorizedDatasetIter``. The permutation itself is only
        applied when :meth:`shuffle` is called.
    use_shared_memory : bool
        Whether to move the permutation buffer into shared memory.

    Raises
    ------
    ValueError
        If ``indices`` is an empty mapping.
    """

    def __init__(
        self, indices, batch_size, drop_last, shuffle, use_shared_memory
    ):
        if isinstance(indices, Mapping):
            # Reject an empty mapping up front; otherwise ``next`` below would
            # raise a bare StopIteration with no useful message.
            if not indices:
                raise ValueError("indices mapping must not be empty")
            self._mapping_keys = list(indices.keys())
            self._device = next(iter(indices.values())).device
            self._id_tensor = _get_id_tensor_from_mapping(
                indices, self._device, self._mapping_keys
            )
        else:
            self._id_tensor = indices
            self._device = indices.device
            self._mapping_keys = None
        # Use a shared memory array to permute indices for shuffling. This is to make sure that
        # the worker processes can see it when persistent_workers=True, where self._indices
        # would not be duplicated every epoch.
        self._indices = torch.arange(
            self._id_tensor.shape[0], dtype=torch.int64
        )
        if use_shared_memory:
            self._indices.share_memory_()
        self.batch_size = batch_size
        self.drop_last = drop_last
        self._shuffle = shuffle

    def shuffle(self):
        """Shuffle the dataset.

        Permutes ``self._indices`` in place through its NumPy view, so the new
        order is visible through the (possibly shared-memory) tensor. The
        NumPy global RNG is used, not the torch RNG.
        """
        np.random.shuffle(self._indices.numpy())

    def __iter__(self):
        """Return an iterator over this worker's minibatches.

        ``_divide_by_worker`` presumably selects the current worker's share of
        the permutation; see that helper for the exact split.
        """
        indices = _divide_by_worker(
            self._indices, self.batch_size, self.drop_last
        )
        # ``self._indices`` stays on the CPU: it is shuffled through a NumPy
        # view and may live in shared memory. The id tensor may be on another
        # device (e.g. the GPU), so move the index there before gathering.
        id_tensor = self._id_tensor[indices.to(self._device)]
        return _TensorizedDatasetIter(
            id_tensor,
            self.batch_size,
            self.drop_last,
            self._mapping_keys,
            self._shuffle,
        )

    def __len__(self):
        """Number of minibatches over the whole dataset.

        Rounds down when ``drop_last`` is set and up otherwise. The count is
        not divided per worker.
        """
        num_samples = self._id_tensor.shape[0]
        return (
            num_samples + (0 if self.drop_last else (self.batch_size - 1))
        ) // self.batch_size
| 243 | |
| 244 | |
| 245 | def _decompose_one_dimension(length, world_size, rank, drop_last): |
no outgoing calls
no test coverage detected