MCPcopy
hub / github.com/hanzhanggit/StackGAN / next_batch

Method next_batch

misc/datasets.py:111–164  ·  view source on GitHub ↗

Return the next `batch_size` examples from this data set.

(self, batch_size, window)

Source from the content-addressed store, hash-verified

109 return np.squeeze(sampled_embeddings_array), sampled_captions
110
def next_batch(self, batch_size, window):
    """Return the next `batch_size` examples from this data set.

    Examples are taken sequentially from the shuffled permutation
    ``self._perm``. When fewer than ``batch_size`` examples remain in the
    current epoch, those leftovers are skipped: the epoch counter is bumped,
    the permutation is reshuffled, and the batch is taken from the start of
    the new permutation.

    Args:
        batch_size: number of examples to return; must not exceed
            ``self._num_examples``.
        window: forwarded unchanged to ``self.sample_embeddings``.

    Returns:
        A 5-element list
        ``[images, wrong_images, embeddings, captions, labels]``:
        ``images`` are the selected examples and ``wrong_images`` are images
        of randomly chosen examples of a different class where one exists.
        Both are cast to float32, mapped via ``x * 2/255 - 1`` (i.e. [0, 255]
        -> [-1, 1], assuming 8-bit pixel values), and passed through
        ``self.transform``. ``embeddings``/``captions`` come from
        ``self.sample_embeddings`` and are ``None`` when the data set has no
        embeddings. ``labels`` is ``None`` when the data set has no labels.

    Raises:
        ValueError: if ``batch_size`` exceeds the number of examples.
    """
    if batch_size > self._num_examples:
        # Validate before touching any state (and without `assert`, which
        # `python -O` strips) so a bad request cannot leave the epoch
        # counters and permutation half-updated.
        raise ValueError(
            'batch_size (%d) exceeds the number of examples (%d)'
            % (batch_size, self._num_examples))

    start = self._index_in_epoch
    self._index_in_epoch += batch_size

    if self._index_in_epoch > self._num_examples:
        # Finished epoch; the remaining (< batch_size) examples of the old
        # permutation are intentionally not returned.
        self._epochs_completed += 1
        # Shuffle the data
        self._perm = np.arange(self._num_examples)
        np.random.shuffle(self._perm)

        # Start next epoch
        start = 0
        self._index_in_epoch = batch_size
    end = self._index_in_epoch

    current_ids = self._perm[start:end]

    # Pick random "wrong" examples, aiming for a different class than the
    # real example at the same position. The first draw and shift are kept
    # as-is so the random stream matches earlier versions of this method.
    fake_ids = np.random.randint(self._num_examples, size=batch_size)
    collision_flag = \
        (self._class_id[current_ids] == self._class_id[fake_ids])
    fake_ids[collision_flag] = \
        (fake_ids[collision_flag] +
         np.random.randint(100, 200)) % self._num_examples
    # A single shift can still land on an example of the same class;
    # resample any such position from the examples of other classes. If
    # every example shares that class, the shifted id is kept (best effort).
    for pos in np.flatnonzero(
            self._class_id[current_ids] == self._class_id[fake_ids]):
        others = np.flatnonzero(
            self._class_id != self._class_id[current_ids[pos]])
        if others.size:
            fake_ids[pos] = np.random.choice(others)

    sampled_images = self._images[current_ids].astype(np.float32)
    sampled_wrong_images = self._images[fake_ids].astype(np.float32)
    # Map pixel values from [0, 255] to [-1, 1].
    sampled_images = sampled_images * (2. / 255) - 1.
    sampled_wrong_images = sampled_wrong_images * (2. / 255) - 1.

    sampled_images = self.transform(sampled_images)
    sampled_wrong_images = self.transform(sampled_wrong_images)
    ret_list = [sampled_images, sampled_wrong_images]

    if self._embeddings is not None:
        filenames = [self._filenames[i] for i in current_ids]
        class_id = [self._class_id[i] for i in current_ids]
        sampled_embeddings, sampled_captions = \
            self.sample_embeddings(self._embeddings[current_ids],
                                   filenames, class_id, window)
        ret_list.append(sampled_embeddings)
        ret_list.append(sampled_captions)
    else:
        ret_list.append(None)
        ret_list.append(None)

    if self._labels is not None:
        ret_list.append(self._labels[current_ids])
    else:
        ret_list.append(None)
    return ret_list
165
166 def next_batch_test(self, batch_size, start, max_captions):
167 """Return the next `batch_size` examples from this data set."""

Callers 4

epoch_sum_images (method) · 0.80
train_one_step (method) · 0.80
epoch_sum_images (method) · 0.80
train (method) · 0.80

Calls 2

transform (method) · 0.95
sample_embeddings (method) · 0.95

Tested by

no test coverage detected