(
self, indices, batch_size, drop_last, shuffle, use_shared_memory
)
| 194 | """ |
| 195 | |
def __init__(
    self, indices, batch_size, drop_last, shuffle, use_shared_memory
):
    """Record the ID tensor, its device, and build the index array used for shuffling.

    Args:
        indices: Either a single tensor of IDs, or a Mapping from keys to such
            tensors. In the Mapping case, the key order is recorded in
            ``self._mapping_keys``. The device is taken from the first value.
            The mapping is then passed to ``_get_id_tensor_from_mapping`` to
            build one ID tensor (presumably combining the per-key IDs; see
            that helper). Otherwise ``indices`` is used as the ID tensor as-is
            and ``self._mapping_keys`` is ``None``.
        batch_size: Stored unchanged as ``self.batch_size``.
        drop_last: Stored unchanged as ``self.drop_last``.
        shuffle: Stored unchanged as ``self._shuffle``.
        use_shared_memory: If truthy, ``self._indices`` is moved into shared
            memory via ``share_memory_()``.

    Raises:
        ValueError: If ``indices`` is an empty Mapping. There is then no
            value to take the device from, and the original code would have
            surfaced a bare ``StopIteration``.
    """
    if isinstance(indices, Mapping):
        if not indices:
            # Fail with a clear message instead of letting next() raise a bare
            # StopIteration, which PEP 479 turns into RuntimeError inside generators.
            raise ValueError(
                "indices must be a non-empty mapping of ID tensors"
            )
        self._mapping_keys = list(indices.keys())
        # Only the first value's device is inspected; the per-key tensors are
        # assumed to share it (TODO confirm callers guarantee this).
        self._device = next(iter(indices.values())).device
        self._id_tensor = _get_id_tensor_from_mapping(
            indices, self._device, self._mapping_keys
        )
    else:
        self._id_tensor = indices
        self._device = indices.device
        self._mapping_keys = None
    # Use a shared memory array to permute indices for shuffling. This is to make sure that
    # the worker processes can see it when persistent_workers=True, where self._indices
    # would not be duplicated every epoch.
    # No device is passed, so this lives on torch's default device regardless
    # of self._device.
    self._indices = torch.arange(
        self._id_tensor.shape[0], dtype=torch.int64
    )
    if use_shared_memory:
        self._indices.share_memory_()
    self.batch_size = batch_size
    self.drop_last = drop_last
    # Kept private so the instance attribute does not shadow the shuffle() method.
    self._shuffle = shuffle
| 220 | |
| 221 | def shuffle(self): |
| 222 | """Shuffle the dataset.""" |
nothing calls this directly
no test coverage detected