Enqueues datapoints from a DataFlow into a TF queue; the model then receives the dequeued tensors.
| 187 | |
| 188 | |
| 189 | class QueueInput(FeedfreeInput): |
| 190 | """ Enqueue datapoints from a DataFlow to a TF queue. |
| 191 | And the model receives dequeued tensors. |
| 192 | """ |
| 193 | |
def __init__(self, ds, queue=None):
    """
    Args:
        ds (DataFlow): the DataFlow supplying datapoints. A repeated wrapper
            around it (``RepeatedData(ds, -1)``) is what gets enqueued.
        queue (tf.QueueBase): optional :class:`tf.QueueBase` whose component
            types match the model's input signature. When left as ``None``,
            a FIFO queue of capacity 50 is created at setup time.

    Raises:
        ValueError: if ``ds`` is not a :class:`DataFlow`.
    """
    if isinstance(ds, DataFlow):
        self.ds = ds
        self.queue = queue
        # Presumably -1 means "repeat forever" (hence the name _inf_ds) —
        # confirm against RepeatedData.
        self._inf_ds = RepeatedData(self.ds, -1)
        self._started = False
        return
    raise ValueError("QueueInput takes a DataFlow! Got {}".format(ds))
| 208 | |
def _size(self):
    """Return the number of datapoints in one pass of the wrapped DataFlow."""
    dataflow = self.ds
    return len(dataflow)
| 211 | |
def _setup(self, inputs):
    """
    Build the graph-side pieces of the queue-based input pipeline.

    Creates (or reuses) one placeholder per input. If no queue was passed to
    the constructor, it creates a default FIFO queue. It then builds the
    enqueue thread and the dequeue op used by :meth:`refill_queue`.

    Args:
        inputs: the model's input signature; each entry is passed to
            ``build_or_reuse_placeholder``.

    Raises:
        ValueError: if ``inputs`` is empty.
    """
    self._input_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently skip this check.
    if not self._input_placehdrs:
        raise ValueError("QueueInput has to be used with some inputs!")
    with self.cached_name_scope():
        if self.queue is None:
            # Default queue: capacity 50, one component per placeholder dtype.
            self.queue = tfv1.FIFOQueue(
                50, [x.dtype for x in self._input_placehdrs],
                name='input_queue')
        logger.info("Setting up the queue '{}' for CPU prefetching ...".format(self.queue.name))
        # The thread is needed whether the queue was user-supplied or created above.
        self.thread = EnqueueThread(self.queue, self._inf_ds, self._input_placehdrs)

        # Used by refill_queue() to drain the queue.
        self._dequeue_op = self.queue.dequeue(name='dequeue_for_reset')
| 225 | |
def refill_queue(self):
    """
    Clear the queue, then call dataflow.__iter__() again and fill into the queue.

    The queue is drained by running the dequeue op in the default TF
    session, so this must be called with a default session active
    (e.g. inside ``with sess.as_default():``).

    Raises:
        RuntimeError: if there is no default session. The enqueue thread is
            not paused in that case.
    """
    # Look up the session BEFORE pausing: failing after pause() would leave
    # the enqueue thread paused with nothing to resume it.
    sess = tfv1.get_default_session()
    if sess is None:
        raise RuntimeError(
            "QueueInput.refill_queue() must be called with a default session!")

    self.thread.pause()  # pause enqueue

    # Each dequeue gives up after 2s; with the producer paused, a dequeue
    # that blocks for that long is taken to mean the queue is empty.
    opt = tfv1.RunOptions()
    opt.timeout_in_ms = 2000  # 2s
    # dequeue until empty
    try:
        while True:
            sess.run(self._dequeue_op, options=opt)
    except tf.errors.DeadlineExceededError:
        pass

    # reset dataflow, start thread
    self.thread.reinitialize_dataflow()
    self.thread.resume()
| 245 | |
| 246 | def _create_ema_callback(self): |
no outgoing calls
no test coverage detected