Input by iterating over a DataFlow and feeding datapoints. Note: If `get_input_tensors()` is called more than once, it returns the same placeholders (i.e. feed points) as the first time. Therefore you can't use it for data-parallel training.
| 65 | |
| 66 | |
class FeedInput(InputSource):
    """
    An input source that iterates over a DataFlow and feeds each datapoint
    into placeholders.

    Note:
        Calling `get_input_tensors()` more than once returns the very same
        placeholders (i.e. feed points) that were returned the first time.
        As a consequence, this input source cannot be used for
        data-parallel training.
    """

    class _FeedCallback(Callback):
        """Callback that feeds the next datapoint of a DataFlow before each run."""

        def __init__(self, ds, placeholders):
            """
            Args:
                ds: the DataFlow to draw datapoints from.
                placeholders (list): the feed points, one per datapoint component.
            """
            self._ds = ds
            # Create the initial iterator through the same path used on reset.
            self._reset()
            self._placeholders = placeholders

        def _before_run(self, _):
            """
            Fetch the next datapoint and return it as the feed_dict of this run.

            `next()` propagates StopIteration once the iterator is exhausted.
            """
            datapoint = next(self._itr)
            num_components = len(datapoint)
            num_feed_points = len(self._placeholders)
            assert num_components == num_feed_points, "[FeedInput] datapoints and inputs are of different length!"
            feed_dict = _make_feeds(self._placeholders, datapoint)
            return tfv1.train.SessionRunArgs(fetches=[], feed_dict=feed_dict)

        def _reset(self):
            """Discard the current iterator and start a fresh one over the DataFlow."""
            self._itr = self._ds.__iter__()

    def __init__(self, ds, infinite=True):
        """
        Args:
            ds (DataFlow): the input DataFlow.
            infinite (bool): When set to False, will raise StopIteration when
                ds is exhausted.
        """
        if not isinstance(ds, DataFlow):
            raise ValueError("FeedInput takes a DataFlow! Got {}".format(ds))
        self.ds = ds
        # In infinite mode the DataFlow is wrapped in RepeatedData(ds, -1);
        # otherwise the DataFlow itself is iterated.
        self._iter_ds = RepeatedData(self.ds, -1) if infinite else self.ds

    def _size(self):
        """Return the length of the original (unwrapped) DataFlow."""
        return len(self.ds)

    def _setup(self, inputs):
        """Build (or reuse) one placeholder per input and create the feeding callback."""
        # placeholders as input are always safe to reuse.
        placeholders = [build_or_reuse_placeholder(spec) for spec in inputs]
        self._all_placehdrs = placeholders
        self._cb = self._FeedCallback(self._iter_ds, self._all_placehdrs)

    def _get_input_tensors(self):
        """Return the placeholders created in `_setup`."""
        return self._all_placehdrs

    def _reset_state(self):
        """Restart iteration by resetting the feeding callback's iterator."""
        self._cb._reset()

    def _get_callbacks(self):
        """
        Return the feeding callback followed by the callback obtained from
        `_get_reset_callback` for the iterated DataFlow.
        """
        feed_callback = self._cb
        reset_callback = _get_reset_callback(self._iter_ds)
        return [feed_callback, reset_callback]
| 124 |
no outgoing calls
no test coverage detected