| 1183 | |
| 1184 | |
class TrainDataLoaderIter(DataLoaderIter):
    """Batch iterator over a training data loader that can restart itself.

    When the underlying loader is exhausted and ``auto_reset`` is enabled, a
    fresh iterator is created from ``data_loader`` and its first batch is
    returned. For a non-empty loader, callers can therefore keep requesting
    batches indefinitely.

    Batches are returned exactly as the underlying loader yields them; this
    class does not split them into inputs and labels.

    Args:
        data_loader: Iterable of batches. It must be re-iterable, because
            ``iter`` is called on it again on every reset.
        auto_reset: If True (default), restart ``data_loader`` when it is
            exhausted; if False, let ``StopIteration`` propagate.
    """

    def __init__(self, data_loader, auto_reset=True):
        super().__init__(data_loader)
        self.auto_reset = auto_reset

    def __next__(self):
        """Return the next batch, restarting the loader once if exhausted.

        Raises:
            StopIteration: If the loader is exhausted and ``auto_reset`` is
                False, or if the restarted loader yields no batches (e.g. an
                empty loader).
        """
        # NOTE(review): assumes the parent ``DataLoaderIter.__init__`` stores
        # ``self.data_loader`` and creates ``self._iterator``; the parent is
        # not visible here — confirm.
        try:
            return next(self._iterator)
        except StopIteration:
            if not self.auto_reset:
                raise
        # Restart outside the ``except`` clause so that a StopIteration from
        # an empty loader is not implicitly chained onto the first one.
        self._iterator = iter(self.data_loader)
        return next(self._iterator)
| 1202 | |
| 1203 | |
| 1204 | class ValDataLoaderIter(TrainDataLoaderIter): |