MCPcopy
hub / github.com/msracver/Deformable-ConvNets / PrefetchingIter

Class PrefetchingIter

lib/utils/PrefetchingIter.py:19–145  ·  view source on GitHub ↗

Base class for prefetching iterators. Takes one or more DataIters (or any class with "reset" and "next" methods) and combines them with prefetching. Parameters: iters (DataIter or list of DataIter) — one or more DataIters (or any class with "reset" and "next" methods); …

Source from the content-addressed store, hash-verified

17
18
19class PrefetchingIter(mx.io.DataIter):
20 """Base class for prefetching iterators. Takes one or more DataIters (
21 or any class with "reset" and "next" methods) and combine them with
22 prefetching. For example:
23
24 Parameters
25 ----------
26 iters : DataIter or list of DataIter
27 one or more DataIters (or any class with "reset" and "next" methods)
28 rename_data : None or list of dict
29 i-th element is a renaming map for i-th iter, in the form of
30 {'original_name' : 'new_name'}. Should have one entry for each entry
31 in iter[i].provide_data
32 rename_label : None or list of dict
33 Similar to rename_data
34
35 Examples
36 --------
37 iter = PrefetchingIter([NDArrayIter({'data': X1}), NDArrayIter({'data': X2})],
38 rename_data=[{'data': 'data1'}, {'data': 'data2'}])
39 """
def __init__(self, iters, rename_data=None, rename_label=None):
    """Wrap ``iters`` and start one daemon prefetch thread per iterator.

    Parameters
    ----------
    iters : DataIter or list of DataIter
        Iterator(s) exposing ``next()``. A single iterator is wrapped in a
        one-element list. Exactly one iterator is currently supported.
    rename_data : None or list of dict
        Stored as-is; presumably consumed by ``provide_data`` (not shown
        here) -- confirm.
    rename_label : None or list of dict
        Stored as-is; presumably consumed by ``provide_label`` -- confirm.
    """
    super(PrefetchingIter, self).__init__()
    if not isinstance(iters, list):
        iters = [iters]
    self.n_iter = len(iters)
    assert self.n_iter == 1, "Our prefetching iter only support 1 DataIter"
    self.iters = iters
    self.rename_data = rename_data
    self.rename_label = rename_label
    # NOTE(review): assumes provide_data is a list of lists of (name, shape)
    # pairs, so this is (number of outer entries) * (leading dim of the first
    # entry's shape) -- confirm against the provide_data property.
    self.batch_size = len(self.provide_data) * self.provide_data[0][0][1][0]
    # Per-iterator handshake between the worker thread and the consumer:
    #   data_taken[i] set -> the worker may fetch the next batch
    #   data_ready[i] set -> the worker has stored a batch in next_batch[i]
    self.data_ready = [threading.Event() for _ in range(self.n_iter)]
    self.data_taken = [threading.Event() for _ in range(self.n_iter)]
    # Start in the "taken" state so every worker fetches its first batch
    # immediately.
    for e in self.data_taken:
        e.set()
    # Workers exit their loop once this is False (__del__ clears it).
    self.started = True
    self.current_batch = [None for _ in range(self.n_iter)]
    self.next_batch = [None for _ in range(self.n_iter)]

    def prefetch_func(self, i):
        """Thread entry: keep ``next_batch[i]`` filled from ``iters[i]``."""
        while True:
            self.data_taken[i].wait()
            if not self.started:
                break
            try:
                self.next_batch[i] = self.iters[i].next()
            except StopIteration:
                # None signals that iters[i] raised StopIteration.
                self.next_batch[i] = None
            # NOTE(review): any other exception ends this thread without
            # setting data_ready[i]; anything waiting on that event would
            # block indefinitely. Consider propagating the error.
            self.data_taken[i].clear()
            self.data_ready[i].set()

    self.prefetch_threads = [
        threading.Thread(target=prefetch_func, args=[self, i])
        for i in range(self.n_iter)
    ]
    for thread in self.prefetch_threads:
        # The attribute form works on Python 2.6+ and 3.x, whereas
        # Thread.setDaemon() is deprecated since Python 3.10.
        thread.daemon = True
        thread.start()
74
75 def __del__(self):
76 self.started = False

Callers 15

train_net · Function · 0.90
train_rpn · Function · 0.90
train_rcnn · Function · 0.90
generate_proposals · Function · 0.90
pred_eval · Function · 0.90
train_net · Function · 0.90
train_rpn · Function · 0.90
train_rcnn · Function · 0.90
generate_proposals · Function · 0.90
pred_eval · Function · 0.90
train_net · Function · 0.90
train_rcnn · Function · 0.90

Calls

no outgoing calls

Tested by 6

generate_proposals · Function · 0.72
pred_eval · Function · 0.72
generate_proposals · Function · 0.72
pred_eval · Function · 0.72
pred_eval · Function · 0.72
pred_eval · Function · 0.72