MCPcopy
hub / github.com/hanzhanggit/StackGAN / sample_embeddings

Method sample_embeddings

misc/datasets.py:87–109  ·  view source on GitHub ↗
(self, embeddings, filenames, class_id, sample_num)

Source from the content-addressed store, hash-verified

85 return images
86
def sample_embeddings(self, embeddings, filenames, class_id, sample_num):
    """Select, or average, one text embedding per image in a batch.

    Args:
        embeddings: array indexed as ``embeddings[i, k, :]`` for image ``i``
            and caption ``k``. A 2-D array, or one with a single caption per
            image (``shape[1] == 1``), is passed through squeezed with no
            sampling. NOTE(review): presumably laid out as
            (batch, num_captions, embedding_dim); confirm against the loader.
        filenames: per-image values forwarded to ``self.readCaptions``.
        class_id: per-image values forwarded to ``self.readCaptions``.
        sample_num: number of distinct captions drawn at random per image.
            With 1, the drawn embedding is returned together with the
            matching entry from ``self.readCaptions``. With more, the drawn
            embeddings are averaged and no captions are collected.

    Returns:
        A ``(embeddings, captions)`` tuple on every path. ``embeddings`` is
        passed through ``np.squeeze``, so a batch of size 1 loses its batch
        axis. ``captions`` is empty unless exactly one caption was sampled
        per image from a multi-caption array.

    Raises:
        ValueError: from ``np.random.choice`` when ``sample_num`` exceeds the
            number of captions (sampling is without replacement).
    """
    if len(embeddings.shape) == 2 or embeddings.shape[1] == 1:
        # Nothing to sample from. Still return the (embeddings, captions)
        # pair produced by the sampling path so callers can always unpack
        # two values. No captions are collected, as in the averaging path.
        return np.squeeze(embeddings), []

    batch_size, embedding_num, _ = embeddings.shape
    sampled_embeddings = []
    sampled_captions = []
    for i in range(batch_size):
        # Draw sample_num distinct caption indices for this image.
        randix = np.random.choice(embedding_num, sample_num, replace=False)
        if sample_num == 1:
            # Take the single element explicitly: int() on a 1-element
            # array is deprecated since NumPy 1.25.
            idx = int(randix[0])
            captions = self.readCaptions(filenames[i], class_id[i])
            sampled_captions.append(captions[idx])
            sampled_embeddings.append(embeddings[i, idx, :])
        else:
            # Collapse the drawn caption embeddings into their mean vector.
            sampled_embeddings.append(np.mean(embeddings[i, randix, :], axis=0))
    return np.squeeze(np.array(sampled_embeddings)), sampled_captions
110
111 def next_batch(self, batch_size, window):
112 """Return the next `batch_size` examples from this data set."""

Callers 1

next_batch · Method · 0.95

Calls 1

readCaptions · Method · 0.95

Tested by

no test coverage detected