
Class BatchData

tensorpack/dataflow/common.py:71–190

Stack datapoints into batches. It produces datapoints of the same number of components as ``ds``, but each component has one new extra dimension of size ``batch_size``. The batch can be either a list of original components, or (by default) a numpy array of original components.
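To make the component-wise grouping described above concrete, here is a minimal pure-Python sketch of the list form of a batch. The helper name `group_components` is hypothetical and is not part of tensorpack; it assumes every datapoint is a list with the same number of components.

```python
def group_components(datapoints):
    """Group datapoints into one list per component.

    Hypothetical helper for illustration only; not tensorpack's API.
    Assumes every datapoint is a list with the same number of components.
    """
    if not datapoints:
        raise ValueError("need at least one datapoint")
    num_components = len(datapoints[0])
    # Component k of the batch collects component k of every datapoint.
    return [[dp[k] for dp in datapoints] for k in range(num_components)]


# Three datapoints, each with two components: an image placeholder and a label.
datapoints = [["img0", 0], ["img1", 1], ["img2", 0]]
batch = group_components(datapoints)
```

The batch keeps two components, and each one now holds three entries, one per datapoint. In the default array form, each per-component list would instead be stacked into a single numpy array whose extra leading dimension has size ``batch_size``.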


class BatchData(ProxyDataFlow):
    """
    Stack datapoints into batches.
    It produces datapoints of the same number of components as ``ds``, but
    each component has one new extra dimension of size ``batch_size``.
    The batch can be either a list of original components, or (by default)
    a numpy array of original components.
    """

    def __init__(self, ds, batch_size, remainder=False, use_list=False):
        """
        Args:
            ds (DataFlow): A dataflow that produces either list or dict.
                When ``use_list=False``, the components of ``ds``
                must be either scalars or :class:`np.ndarray`, and have to be consistent in shapes.
            batch_size(int): batch size
            remainder (bool): When the remaining datapoints in ``ds`` is not
                enough to form a batch, whether or not to also produce the remaining
                data as a smaller batch.
                If set to False, all produced datapoints are guaranteed to have the same batch size.
                If set to True, `len(ds)` must be accurate.
            use_list (bool): if True, each component will contain a list
                of datapoints instead of an numpy array of an extra dimension.
        """
        super(BatchData, self).__init__(ds)
        if not remainder:
            try:
                assert batch_size <= len(ds)
            except NotImplementedError:
                pass
        self.batch_size = int(batch_size)
        assert self.batch_size > 0
        self.remainder = remainder
        self.use_list = use_list

    def __len__(self):
        ds_size = len(self.ds)
        div = ds_size // self.batch_size
        rem = ds_size % self.batch_size
        if rem == 0:
            return div
        return div + int(self.remainder)

    def __iter__(self):
        """
        Yields:
            Batched data by stacking each component on an extra 0th dimension.
        """
        holder = []
        for data in self.ds:
            holder.append(data)
            if len(holder) == self.batch_size:
                yield BatchData.aggregate_batch(holder, self.use_list)
                del holder[:]
        if self.remainder and len(holder) > 0:
            yield BatchData.aggregate_batch(holder, self.use_list)

    @staticmethod
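
The batch-count arithmetic in `__len__` and the holder loop in `__iter__` can be tried out without tensorpack. The sketch below mirrors that logic in plain Python; `num_batches` and `iter_batches` are hypothetical names chosen for illustration, and the batches are left as plain lists rather than being passed through `aggregate_batch`.

```python
def num_batches(ds_size, batch_size, remainder=False):
    """Mirror the arithmetic of __len__: full batches, plus one partial batch if kept."""
    div, rem = divmod(ds_size, batch_size)
    if rem == 0:
        return div
    return div + int(remainder)


def iter_batches(items, batch_size, remainder=False):
    """Mirror the holder loop of __iter__, yielding plain lists (illustrative only)."""
    holder = []
    for item in items:
        holder.append(item)
        if len(holder) == batch_size:
            yield list(holder)  # copy before clearing, so the yielded batch is stable
            holder.clear()
    # Leftover datapoints are emitted only when remainder=True.
    if remainder and holder:
        yield list(holder)


full_only = list(iter_batches(range(10), 4))
with_tail = list(iter_batches(range(10), 4, remainder=True))
```

With 10 datapoints and a batch size of 4, `full_only` holds two batches of four and drops the last two datapoints, while `with_tail` adds a third batch of size two. The counts agree with `num_batches(10, 4)` and `num_batches(10, 4, remainder=True)` respectively.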

Callers 15

get_imagenet_dataflow · Function · 0.90
get_imagenet_dataflow · Function · 0.90
get_imagenet_dataflow · Function · 0.90
get_test_data · Function · 0.90
get_imagenet_dataflow · Function · 0.90
get_data · Function · 0.90
get_data · Function · 0.85
get_celebA_data · Function · 0.85
get_data · Function · 0.85
sample · Function · 0.85
get_data · Function · 0.85
get_data · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected