Load a state dictionary to restore the data loader's state.
(self, state_dict: dict[str, Any])
| 231 | } |
| 232 | |
| 233 | def load_state_dict(self, state_dict: dict[str, Any]) -> None: |
| 234 | """Load a state dictionary to restore the data loader's state.""" |
| 235 | self._excluded_indices = set(state_dict.get("excluded_indices", [])) |
| 236 | # Set epoch to one less than target since reshuffle() increments it |
| 237 | self._epoch = state_dict["epoch"] - 1 |
| 238 | self.reshuffle() |
| 239 | assert self._epoch == state_dict["epoch"] |
| 240 | self.batches_processed = state_dict["batches_processed"] |
| 241 | self._current_iter = None |
| 242 | |
| 243 | def exclude_index(self, index: int) -> None: |
| 244 | """Exclude a dataset index from future iterations. |