Stack datapoints into batches. Each produced datapoint has the same number of components as those of ``ds``, but every component gains one extra leading (0th) dimension of size ``batch_size`` (the final batch may be smaller when ``remainder=True``). Each batched component is either a list of the original components or (by default) a numpy array stacking them.
| 69 | |
| 70 | |
| 71 | class BatchData(ProxyDataFlow): |
| 72 | """ |
| 73 | Stack datapoints into batches. |
| 74 | It produces datapoints of the same number of components as ``ds``, but |
| 75 | each component has one new extra dimension of size ``batch_size``. |
| 76 | The batch can be either a list of original components, or (by default) |
| 77 | a numpy array of original components. |
| 78 | """ |
| 79 | |
def __init__(self, ds, batch_size, remainder=False, use_list=False):
    """
    Args:
        ds (DataFlow): A dataflow that produces either list or dict.
            When ``use_list=False``, the components of ``ds``
            must be either scalars or :class:`np.ndarray`, and have to be consistent in shapes.
        batch_size (int): batch size. Converted with ``int()``; must be positive.
        remainder (bool): When the remaining datapoints in ``ds`` are not
            enough to form a full batch, whether or not to also produce the remaining
            data as a smaller batch.
            If set to False, all produced datapoints are guaranteed to have the same batch size.
            If set to True, `len(ds)` must be accurate.
        use_list (bool): if True, each component will contain a list
            of datapoints instead of a numpy array with an extra dimension.

    Raises:
        ValueError: if ``batch_size`` is not positive, or if ``remainder`` is
            False and ``len(ds)`` is known and smaller than ``batch_size``
            (no full batch could ever be formed).
    """
    super(BatchData, self).__init__(ds)
    # Normalize before validating so the comparisons below use the same
    # value that is stored on the instance.
    batch_size = int(batch_size)
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O``. A non-positive size would make __len__ divide by zero
    # and __iter__ never reach a full batch.
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))
    if not remainder:
        try:
            ds_size = len(ds)
        except NotImplementedError:
            # len(ds) is not available; the size check cannot be performed.
            ds_size = None
        if ds_size is not None and batch_size > ds_size:
            raise ValueError(
                "batch_size ({}) is larger than len(ds) ({}) and remainder=False, "
                "so no batch would ever be produced".format(batch_size, ds_size))
    self.batch_size = batch_size
    self.remainder = remainder
    self.use_list = use_list
| 105 | |
def __len__(self):
    """
    Number of batches this dataflow will produce.

    Full batches are counted by floor division of ``len(self.ds)`` by
    ``self.batch_size``. When a partial batch is left over,
    ``int(self.remainder)`` is added to the count.
    """
    full_batches, leftover = divmod(len(self.ds), self.batch_size)
    if not leftover:
        return full_batches
    # Only a leftover partial batch can contribute the remainder term.
    return full_batches + int(self.remainder)
| 113 | |
def __iter__(self):
    """
    Yields:
        Batched data by stacking each component on an extra 0th dimension.
        Each batch is built by ``BatchData.aggregate_batch`` from
        ``self.batch_size`` consecutive datapoints of ``self.ds``. A trailing
        partial batch is emitted only when ``self.remainder`` is truthy.
    """
    pending = []
    for datapoint in self.ds:
        pending.append(datapoint)
        if len(pending) != self.batch_size:
            continue
        yield BatchData.aggregate_batch(pending, self.use_list)
        # Empty the same list object in place before collecting the next batch.
        del pending[:]
    # Flush whatever did not fill a complete batch, if requested.
    if self.remainder and pending:
        yield BatchData.aggregate_batch(pending, self.use_list)
| 127 | |
| 128 | @staticmethod |
no outgoing calls
no test coverage detected