| 622 | |
| 623 | class _PrefetchingIter(object): |
| 624 | def __init__(self, dataloader, dataloader_it, num_threads=None): |
| 625 | self.queue = Queue(1) |
| 626 | self.dataloader_it = dataloader_it |
| 627 | self.dataloader = dataloader |
| 628 | self.num_threads = num_threads |
| 629 | |
| 630 | self.use_thread = dataloader.use_prefetch_thread |
| 631 | self.use_alternate_streams = dataloader.use_alternate_streams |
| 632 | self.device = self.dataloader.device |
| 633 | if self.use_alternate_streams and self.device.type == "cuda": |
| 634 | self.stream = torch.cuda.Stream(device=self.device) |
| 635 | else: |
| 636 | self.stream = None |
| 637 | self._shutting_down = False |
| 638 | if self.use_thread: |
| 639 | self._done_event = threading.Event() |
| 640 | thread = threading.Thread( |
| 641 | target=_prefetcher_entry, |
| 642 | args=( |
| 643 | dataloader_it, |
| 644 | dataloader, |
| 645 | self.queue, |
| 646 | num_threads, |
| 647 | self.stream, |
| 648 | self._done_event, |
| 649 | ), |
| 650 | daemon=True, |
| 651 | ) |
| 652 | thread.start() |
| 653 | self.thread = thread |
| 654 | |
| 655 | def __iter__(self): |
| 656 | return self |