(self, indices, batch_size, drop_last, ddp_seed, shuffle)
| 261 | """ |
| 262 | |
def __init__(self, indices, batch_size, drop_last, ddp_seed, shuffle):
    """Split ``indices`` across DDP ranks and build this rank's local view.

    Parameters
    ----------
    indices : Mapping or sized collection
        The IDs to iterate over. For a ``Mapping``, the lengths of all values
        are summed to form one concatenated ID space, and the key order is
        remembered in ``self._mapping_keys`` (presumably per-type IDs --
        TODO confirm against callers).
    batch_size : int
        Stored as ``self.batch_size``; not used inside the constructor.
    drop_last : bool
        Forwarded to ``_decompose_one_dimension``. When false, this rank's
        slice may extend past the end of the data, and ``self._indices`` is
        padded up to ``self.num_samples`` entries.
    ddp_seed : int
        Stored as ``self.seed`` (presumably used to seed shuffling).
    shuffle : bool
        Stored as ``self._shuffle``.
    """
    is_mapping = isinstance(indices, Mapping)
    if is_mapping:
        self._mapping_keys = list(indices.keys())
        len_indices = sum(len(v) for v in indices.values())
    else:
        self._mapping_keys = None
        len_indices = len(indices)

    self.rank = dist.get_rank()
    self.num_replicas = dist.get_world_size()
    self.seed = ddp_seed
    self.epoch = 0
    self.batch_size = batch_size
    self.drop_last = drop_last
    self._shuffle = shuffle

    # [local_lower_bound, local_upper_bound) is this rank's slice of the
    # concatenated ID space.
    (
        self.local_lower_bound,
        self.local_upper_bound,
    ) = _decompose_one_dimension(
        len_indices, self.num_replicas, self.rank, drop_last
    )
    self.num_samples = self.local_upper_bound - self.local_lower_bound
    # Number of real (non-padding) IDs in this rank's slice.
    self.local_num_indices = self.num_samples
    if self.local_upper_bound > len_indices:
        # Running past the end of the data is only expected when padding.
        assert not drop_last
        # Clamp at 0: if the lower bound is also past the end (a rank whose
        # whole slice is padding), the raw difference would be negative and
        # the torch.arange calls below would raise.
        self.local_num_indices = max(len_indices - self.local_lower_bound, 0)

    if is_mapping:
        self._id_tensor = _split_to_local_id_tensor_from_mapping(
            indices,
            self._mapping_keys,
            self.local_lower_bound,
            self.local_upper_bound,
        )
    else:
        self._id_tensor = _split_to_local_id_tensor(
            indices, self.local_lower_bound, self.local_upper_bound
        )
    self._device = self._id_tensor.device

    # self._indices holds positions into self._id_tensor. It is allocated
    # with no device argument, so it lives on the default (CPU) device.
    # The first local_num_indices entries are 0..local_num_indices-1. When
    # drop_last is False, the remaining padding entries restart from 0.
    self._indices = torch.empty(self.num_samples, dtype=torch.int64)
    torch.arange(
        self.local_num_indices, out=self._indices[: self.local_num_indices]
    )
    if not drop_last:
        torch.arange(
            self.num_samples - self.local_num_indices,
            out=self._indices[self.local_num_indices :],
        )
    # The split helpers are expected to return exactly num_samples IDs.
    assert len(self._id_tensor) == self.num_samples
| 313 | |
| 314 | def shuffle(self): |
| 315 | """Shuffles the dataset.""" |
nothing calls this directly
no test coverage detected