Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors. When the dataset is on the GPU, this significantly reduces the overhead. This class additionally saves the index tensor in shared memory and therefore avoids duplicating the same index tensor during shuffling.
| 253 | |
| 254 | |
class DDPTensorizedDataset(torch.utils.data.IterableDataset):
    """Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors.
    When the dataset is on the GPU, this significantly reduces the overhead.

    This class additionally saves the index tensor in shared memory and therefore
    avoids duplicating the same index tensor during shuffling.
    """

    def __init__(self, indices, batch_size, drop_last, ddp_seed, shuffle):
        """Partition ``indices`` across the distributed ranks and set up local state.

        Parameters
        ----------
        indices : sized collection or Mapping of sized collections
            The IDs to iterate over. If it is a ``Mapping``, its keys are
            remembered in ``self._mapping_keys`` and the total length is the
            sum of the lengths of its values.
            (Presumably tensors or dicts of tensors -- confirm against callers.)
        batch_size : int
            Stored as ``self.batch_size``.
        drop_last : bool
            Stored as ``self.drop_last`` and forwarded to
            ``_decompose_one_dimension``. When false, the local position
            buffer is padded up to ``self.num_samples`` entries.
        ddp_seed : int
            Stored as ``self.seed``.
        shuffle : bool
            Stored as ``self._shuffle``.
        """
        is_mapping = isinstance(indices, Mapping)
        if is_mapping:
            self._mapping_keys = list(indices.keys())
            len_indices = sum(len(v) for v in indices.values())
        else:
            self._mapping_keys = None
            len_indices = len(indices)

        self.rank = dist.get_rank()
        self.num_replicas = dist.get_world_size()
        self.seed = ddp_seed
        self.epoch = 0
        self.batch_size = batch_size
        self.drop_last = drop_last
        self._shuffle = shuffle
        (
            self.local_lower_bound,
            self.local_upper_bound,
        ) = _decompose_one_dimension(
            len_indices, self.num_replicas, self.rank, drop_last
        )
        self.num_samples = self.local_upper_bound - self.local_lower_bound
        # Number of positions in this rank's range that fall inside the
        # global index list; the rest (if any) is padding.
        self.local_num_indices = self.num_samples
        if self.local_upper_bound > len_indices:
            assert not drop_last
            # Clamp at zero: when this rank's range starts past the end of the
            # index list the difference is negative, which would make the
            # torch.arange calls below raise and the slice bounds mismatch.
            self.local_num_indices = max(
                len_indices - self.local_lower_bound, 0
            )

        if is_mapping:
            self._id_tensor = _split_to_local_id_tensor_from_mapping(
                indices,
                self._mapping_keys,
                self.local_lower_bound,
                self.local_upper_bound,
            )
        else:
            self._id_tensor = _split_to_local_id_tensor(
                indices, self.local_lower_bound, self.local_upper_bound
            )
        self._device = self._id_tensor.device
        # padding self._indices when drop_last = False (self._indices always on cpu)
        self._indices = torch.empty(self.num_samples, dtype=torch.int64)
        # Fill the leading slots in place with 0 .. local_num_indices - 1.
        torch.arange(
            self.local_num_indices, out=self._indices[: self.local_num_indices]
        )
        if not drop_last:
            # Fill the remaining (padding) slots in place, restarting from 0.
            torch.arange(
                self.num_samples - self.local_num_indices,
                out=self._indices[self.local_num_indices :],
            )
        assert len(self._id_tensor) == self.num_samples
no outgoing calls
no test coverage detected