Base class for prefetching iterators. Takes one or more DataIters (or any class with "reset" and "next" methods) and combines them with prefetching. Parameters: iters (DataIter or list of DataIter) — one or more DataIters (or any class with "reset" and "next" methods).
| 17 | |
| 18 | |
| 19 | class PrefetchingIter(mx.io.DataIter): |
| 20 | """Base class for prefetching iterators. Takes one or more DataIters ( |
| 21 | or any class with "reset" and "next" methods) and combine them with |
| 22 | prefetching. For example: |
| 23 | |
| 24 | Parameters |
| 25 | ---------- |
| 26 | iters : DataIter or list of DataIter |
| 27 | one or more DataIters (or any class with "reset" and "next" methods) |
| 28 | rename_data : None or list of dict |
| 29 | i-th element is a renaming map for i-th iter, in the form of |
| 30 | {'original_name' : 'new_name'}. Should have one entry for each entry |
| 31 | in iter[i].provide_data |
| 32 | rename_label : None or list of dict |
| 33 | Similar to rename_data |
| 34 | |
| 35 | Examples |
| 36 | -------- |
| 37 | iter = PrefetchingIter([NDArrayIter({'data': X1}), NDArrayIter({'data': X2})], |
| 38 | rename_data=[{'data': 'data1'}, {'data': 'data2'}]) |
| 39 | """ |
def __init__(self, iters, rename_data=None, rename_label=None):
    """Wrap ``iters`` and start one background prefetch thread per iterator.

    Parameters
    ----------
    iters : DataIter or list of DataIter
        A single (non-list) iterator is wrapped into a one-element list.
        Exactly one iterator is currently supported (checked by an assert).
    rename_data : None or list of dict
        Stored on ``self.rename_data`` for use by the rest of the class.
    rename_label : None or list of dict
        Stored on ``self.rename_label`` for use by the rest of the class.
    """
    super(PrefetchingIter, self).__init__()
    if not isinstance(iters, list):
        iters = [iters]
    self.n_iter = len(iters)
    assert self.n_iter == 1, "Our prefetching iter only support 1 DataIter"
    self.iters = iters
    self.rename_data = rename_data
    self.rename_label = rename_label
    # Uses the ``provide_data`` property defined elsewhere in the class.
    # NOTE(review): presumably provide_data is a list (one entry per device?)
    # of [(name, shape), ...] lists, making this entries * per-entry batch
    # dimension -- TODO confirm against the provide_data implementation.
    self.batch_size = len(self.provide_data) * self.provide_data[0][0][1][0]

    # Per-iterator handshake with prefetch thread i:
    #   data_taken[i] set -> thread i may fetch the next batch
    #   data_ready[i] set -> next_batch[i] holds the freshly fetched batch
    # The consumer side (presumably clearing data_ready / setting data_taken)
    # lives in methods outside this one.
    self.data_ready = [threading.Event() for _ in range(self.n_iter)]
    self.data_taken = [threading.Event() for _ in range(self.n_iter)]
    # Start in the "taken" state so each thread fetches its first batch
    # immediately after it is started.
    for e in self.data_taken:
        e.set()
    self.started = True
    self.current_batch = [None for _ in range(self.n_iter)]
    self.next_batch = [None for _ in range(self.n_iter)]

    def prefetch_func(self, i):
        """Thread entry: fetch batches from ``self.iters[i]`` until stopped."""
        while True:
            self.data_taken[i].wait()
            # Clearing ``started`` only stops this loop once data_taken[i]
            # is set again to wake the wait above.
            if not self.started:
                break
            # NOTE(review): any exception other than StopIteration ends this
            # thread without setting data_ready[i]; a consumer waiting on
            # that event would block indefinitely.
            try:
                self.next_batch[i] = self.iters[i].next()
            except StopIteration:
                # A None batch records that iters[i] raised StopIteration.
                self.next_batch[i] = None
            self.data_taken[i].clear()
            self.data_ready[i].set()

    # NOTE(review): args=[self, i] gives each live thread a strong reference
    # to self, so self likely stays reachable while the threads run and
    # __del__ may not fire on its own -- consider an explicit shutdown method.
    self.prefetch_threads = [threading.Thread(target=prefetch_func, args=[self, i])
                             for i in range(self.n_iter)]
    for thread in self.prefetch_threads:
        # Thread.setDaemon() is deprecated since Python 3.10; assigning the
        # ``daemon`` attribute is equivalent and works on Python 2.6+ and 3.x.
        thread.daemon = True
        thread.start()
| 74 | |
| 75 | def __del__(self): |
| 76 | self.started = False |
no outgoing calls