Return the next `batch_size` examples from this data set.
Signature: `next_batch(self, batch_size, window)`
| 109 | return np.squeeze(sampled_embeddings_array), sampled_captions |
| 110 | |
def next_batch(self, batch_size, window, max_collision_retries=10):
    """Return the next `batch_size` examples from this data set.

    Examples are taken in the order given by ``self._perm``.

    If the requested batch would run past ``self._num_examples``:
    - ``self._epochs_completed`` is incremented.
    - ``self._perm`` is replaced by a fresh random permutation.
    - The batch is taken from the start of that permutation.
    The examples left at the tail of the finished epoch are skipped.

    For every real example a random "wrong" example is also drawn. Ids whose
    class id equals the real example's class id are shifted by a random
    offset in [100, 200) modulo ``self._num_examples``, and the check is
    repeated. After ``max_collision_retries`` rounds any remaining
    collisions are kept as-is.

    Args:
        batch_size: Number of examples to return. Must not exceed
            ``self._num_examples``.
        window: Passed through unchanged to ``self.sample_embeddings``.
        max_collision_retries: Maximum number of shift-and-recheck rounds
            used to separate wrong examples from their real counterparts'
            class. At least one round is always performed.

    Returns:
        A list ``[images, wrong_images, embeddings, captions, labels]``.
        Both image batches are cast to float32 and rescaled with
        ``x * (2 / 255) - 1`` (0 maps to -1, 255 maps to 1). Each batch is
        then passed through ``self.transform``. ``embeddings`` and
        ``captions`` are ``None`` when ``self._embeddings`` is ``None``.
        ``labels`` is ``None`` when ``self._labels`` is ``None``.

    Raises:
        ValueError: If ``batch_size`` exceeds ``self._num_examples``.
            The check happens before any internal state is modified.
    """
    # Validate before mutating the epoch bookkeeping. A plain `assert` would
    # be stripped under `python -O` and ran only after state had changed.
    if batch_size > self._num_examples:
        raise ValueError(
            "batch_size (%d) must not exceed the number of examples (%d)"
            % (batch_size, self._num_examples))

    start = self._index_in_epoch
    self._index_in_epoch += batch_size

    if self._index_in_epoch > self._num_examples:
        # Finished epoch: reshuffle and restart from the top of the new
        # permutation. The unused tail of the old epoch is dropped.
        self._epochs_completed += 1
        self._perm = np.arange(self._num_examples)
        np.random.shuffle(self._perm)
        start = 0
        self._index_in_epoch = batch_size
    end = self._index_in_epoch

    current_ids = self._perm[start:end]
    current_classes = self._class_id[current_ids]

    # Draw a random "wrong" partner for each example. Shift partners that
    # share the real example's class, then re-check: a single shift can
    # land on the same class again. The first round draws from the RNG
    # exactly like a single unconditional shift would.
    fake_ids = np.random.randint(self._num_examples, size=batch_size)
    collision_flag = self._class_id[fake_ids] == current_classes
    for _ in range(max(1, max_collision_retries)):
        fake_ids[collision_flag] = (
            fake_ids[collision_flag] + np.random.randint(100, 200)
        ) % self._num_examples
        collision_flag = self._class_id[fake_ids] == current_classes
        if not collision_flag.any():
            break

    def _load_images(ids):
        # Gather images, cast to float32, rescale (0 -> -1, 255 -> 1), transform.
        images = self._images[ids].astype(np.float32)
        return self.transform(images * (2. / 255) - 1.)

    # Real images are transformed before wrong images, as in the original.
    ret_list = [_load_images(current_ids), _load_images(fake_ids)]

    if self._embeddings is not None:
        filenames = [self._filenames[i] for i in current_ids]
        class_id = [self._class_id[i] for i in current_ids]
        sampled_embeddings, sampled_captions = self.sample_embeddings(
            self._embeddings[current_ids], filenames, class_id, window)
        ret_list.extend([sampled_embeddings, sampled_captions])
    else:
        ret_list.extend([None, None])

    if self._labels is not None:
        ret_list.append(self._labels[current_ids])
    else:
        ret_list.append(None)
    return ret_list
| 165 | |
| 166 | def next_batch_test(self, batch_size, start, max_captions): |
| 167 | """Return the next `batch_size` examples from this data set.""" |
no test coverage detected