this is really hacky and I'm not proud of it, but there doesn't seem to be a better way in PyTorch to just create an infinite dataloader?
| 572 | return train_dataset, test_dataset |
| 573 | |
class InfiniteDataLoader:
    """
    Endless stream of batches drawn from `dataset`.

    PyTorch's DataLoader stops at the end of an epoch. As a workaround this
    wraps one around a RandomSampler that samples with replacement and is
    configured for int(1e10) samples, so in practice it never runs dry.

    Batches can be fetched either with the explicit `.next()` method or via
    the standard iterator protocol (`next(loader)`, `for batch in loader:`).
    Note that a `for` loop over this object only ends on `break`.

    Args:
        dataset: the dataset handed to both RandomSampler and DataLoader.
        **kwargs: forwarded unchanged to DataLoader (e.g. batch_size,
            num_workers, pin_memory). Do not pass `sampler`: it is already
            supplied here and a duplicate keyword raises TypeError.
    """

    def __init__(self, dataset, **kwargs):
        # sampling with replacement lets num_samples exceed len(dataset)
        train_sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=int(1e10))
        self.train_loader = DataLoader(dataset, sampler=train_sampler, **kwargs)
        self.data_iter = iter(self.train_loader)

    def __iter__(self):
        # the object is its own iterator; iteration state lives in data_iter
        return self

    def __next__(self):
        # delegate so that subclasses overriding next() stay consistent
        return self.next()

    def next(self):
        """Return the next batch, restarting the underlying iterator if it is exhausted."""
        try:
            batch = next(self.data_iter)
        except StopIteration:
            # only reachable after ~1e10 samples (i.e. basically never);
            # start a fresh pass over the loader and keep going
            self.data_iter = iter(self.train_loader)
            batch = next(self.data_iter)
        return batch
| 592 | |
| 593 | # ----------------------------------------------------------------------------- |
| 594 | if __name__ == '__main__': |