MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / FeedInput

Class FeedInput

tensorpack/input_source/input_source.py:67–122  ·  view source on GitHub ↗

Input by iterating over a DataFlow and feeding datapoints. Note: If `get_input_tensors()` is called more than once, it will return the same placeholders (i.e. feed points) as the first time. Therefore you cannot use it for data-parallel training.

Source from the content-addressed store, hash-verified

65
66
class FeedInput(InputSource):
    """
    InputSource that iterates over a DataFlow and feeds each datapoint
    into placeholders through ``feed_dict``.

    Note:
        Repeated calls to `get_input_tensors()` all return the same placeholders
        (i.e. feed points) that were returned by the first call.
        For that reason this input source cannot be used for data-parallel training.
    """

    class _FeedCallback(Callback):
        """
        Callback that, before every session run, takes the next datapoint
        from the wrapped iterable and returns it as the run's ``feed_dict``.
        """

        def __init__(self, ds, placeholders):
            self._ds = ds
            self._placeholders = placeholders
            # Obtain the initial iterator the same way a later reset does.
            self._reset()

        def _before_run(self, _):
            datapoint = next(self._itr)
            # One value is needed per placeholder; a mismatch means the DataFlow
            # and the model inputs disagree.
            assert len(datapoint) == len(self._placeholders), \
                "[FeedInput] datapoints and inputs are of different length!"
            return tfv1.train.SessionRunArgs(
                fetches=[],
                feed_dict=_make_feeds(self._placeholders, datapoint),
            )

        def _reset(self):
            # Discard any partially consumed iterator and start a fresh pass over ds.
            self._itr = self._ds.__iter__()

    def __init__(self, ds, infinite=True):
        """
        Args:
            ds (DataFlow): the DataFlow to draw datapoints from.
            infinite (bool): if False, StopIteration is raised once ``ds`` is
                exhausted; if True, ``ds`` is wrapped in ``RepeatedData(ds, -1)``.

        Raises:
            ValueError: if ``ds`` is not a DataFlow.
        """
        if not isinstance(ds, DataFlow):
            raise ValueError("FeedInput takes a DataFlow! Got {}".format(ds))
        self.ds = ds
        self._iter_ds = RepeatedData(ds, -1) if infinite else ds

    def _size(self):
        return len(self.ds)

    def _setup(self, inputs):
        # Placeholders used as inputs can always be reused safely.
        self._all_placehdrs = [build_or_reuse_placeholder(spec) for spec in inputs]
        self._cb = self._FeedCallback(self._iter_ds, self._all_placehdrs)

    def _get_input_tensors(self):
        # The placeholders created once in _setup; returned on every call.
        return self._all_placehdrs

    def _reset_state(self):
        # Restart the feeding callback's iterator.
        self._cb._reset()

    def _get_callbacks(self):
        feeder = self._cb
        resetter = _get_reset_callback(self._iter_ds)
        return [feeder, resetter]
123
124

Callers 4

`__init__` — Method · 0.85
`apply_default_prefetch` — Function · 0.85
`mnist-tflayers.py` — File · 0.85
`mnist-convnet.py` — File · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected