| 84 | |
| 85 | |
class WorkQueue:
    """Run ``work_fn`` on queued items using a fixed pool of worker threads.

    Items passed to :meth:`put` are consumed by ``num_threads`` worker
    threads, each of which calls ``work_fn(item)``. The underlying queue is
    bounded to ``num_threads`` entries, so :meth:`put` blocks while every
    slot is occupied, which throttles the producer to the workers' pace.

    ``None`` is used internally as the shutdown sentinel: a worker that
    dequeues ``None`` exits, so ``None`` should not be enqueued as work.

    If ``work_fn`` raises, the worker records the exception and keeps
    consuming items. Producers blocked in :meth:`put`, and the sentinel
    hand-off in :meth:`join`, therefore cannot deadlock on a dead worker.
    The first recorded exception is re-raised from :meth:`join`.
    """

    def __init__(self, work_fn, num_threads=1):
        """Create the bounded queue and start ``num_threads`` worker threads.

        Args:
            work_fn: Callable invoked as ``work_fn(item)`` for every
                non-``None`` item taken from the queue.
            num_threads: Number of worker threads; also the queue's maxsize.
        """
        self.queue = Queue(num_threads)
        # Exceptions raised by work_fn, in the order the workers caught them.
        # list.append is atomic in CPython, so workers can share this list.
        self.errors = []
        self.threads = [
            Thread(target=self.thread_fn, args=(work_fn,))
            for _ in range(num_threads)
        ]
        for thread in self.threads:
            thread.start()

    def join(self):
        """Shut down the workers after every previously queued item is handled.

        Enqueues one ``None`` sentinel per worker. Because the queue is FIFO,
        the sentinels come after every item already put. The method then
        waits for all workers to exit.

        Raises:
            Exception: The first exception raised by ``work_fn``, if any
                item failed.
        """
        for _ in self.threads:
            self.queue.put(None)
        for thread in self.threads:
            thread.join()
        if self.errors:
            raise self.errors[0]

    def thread_fn(self, work_fn):
        """Worker loop: apply ``work_fn`` to items until ``None`` is dequeued."""
        item = self.queue.get()
        while item is not None:
            try:
                work_fn(item)
            except Exception as err:
                # Keep the worker alive. If it died here, the bounded queue
                # would stop draining and put()/join() could block forever.
                self.errors.append(err)
            item = self.queue.get()

    def put(self, data):
        """Enqueue ``data`` for a worker; blocks while the queue is full."""
        self.queue.put(data)
| 111 | |
| 112 | class FeaturePairsDataset(torch.utils.data.Dataset): |