MCPcopy
hub / github.com/hojonathanho/diffusion / InceptionFeatures

Class InceptionFeatures

diffusion_tf/tpu_utils/tpu_utils.py:209–252  ·  view source on GitHub ↗

Compute and store Inception features for a dataset

Sourced from the content-addressed store (hash-verified)

207
208
class InceptionFeatures:
  """
  Compute and store Inception features for a dataset.

  The graph ops are built in ``__init__``; the features themselves are
  computed lazily on the first call to ``get`` and cached on the instance,
  so later calls return the cached values without running the session.
  """

  def __init__(self, dataset, strategy, limit_dataset_size=0):
    """
    Build the ops that run Inception over ``dataset``.

    Args:
      dataset: dataset whose elements are indexed as ``x_['image']`` below
        (presumably dicts holding an image tensor -- confirm with callers).
      strategy: distribution strategy used to distribute ``dataset`` and to
        run the Inception op via ``distributed``.
      limit_dataset_size: if > 0, the dataset is truncated with
        ``dataset.take(limit_dataset_size)``; 0 (default) means no limit.
    """
    # distributed dataset iterator
    if limit_dataset_size > 0:
      dataset = dataset.take(limit_dataset_size)
    self.ds_iterator = strategy.experimental_distribute_dataset(dataset).make_initializable_iterator()

    # inception network on the dataset; `distributed` (defined elsewhere in
    # this file) is asked to combine per-replica outputs with reduction='concat'
    self.inception_real = distributed(
      lambda x_: run_inception(tfgan.eval.preprocess_image(x_['image'])),
      args=(next(self.ds_iterator),), reduction='concat', strategy=strategy)

    self.cached_inception_real = None  # cached inception features
    self.real_inception_score = None  # saved inception scores for the dataset

  def get(self, sess):
    """
    Return the Inception features and Inception score of the dataset.

    On the first call this resets the dataset iterator, runs the Inception
    op until the iterator raises ``OutOfRangeError``, concatenates the
    batches along axis 0 as float64, and computes the Inception score from
    the logits. The logits are then dropped, so the cached dict keeps only
    the 'pool_3' features. Subsequent calls return the cached values.

    The cache is published only after every step succeeds. If a step fails,
    the instance stays uncached and the next call recomputes from scratch.

    Args:
      sess: session used to run the iterator initializer and Inception op.

    Returns:
      A tuple ``(features, score)``. ``features`` is a dict with key
      'pool_3' mapping to a float64 array, and ``score`` is a float.

    Raises:
      ValueError: if the dataset produced no batches.
    """
    if self.cached_inception_real is not None:
      return self.cached_inception_real, self.real_inception_score

    # On the first invocation, compute Inception activations for the eval dataset
    print('computing inception features on the eval set...')
    sess.run(self.ds_iterator.initializer)  # reset the eval dataset iterator
    inception_real_batches, tstart = [], time.perf_counter()
    while True:
      try:
        inception_real_batches.append(sess.run(self.inception_real))
      except tf.errors.OutOfRangeError:
        break

    if not inception_real_batches:
      # np.concatenate([]) would fail with an unhelpful message; say why instead.
      raise ValueError('eval dataset produced no batches; cannot compute inception features')

    features = {
      feat_key: np.concatenate([batch[feat_key] for batch in inception_real_batches], axis=0).astype(np.float64)
      for feat_key in ['pool_3', 'logits']
    }
    del inception_real_batches  # release per-batch copies before scoring
    print('cached eval inception tensors: logits: {}, pool_3: {} (time: {})'.format(
      features['logits'].shape, features['pool_3'].shape,
      time.perf_counter() - tstart))

    score = float(classifier_metrics_numpy.classifier_score_from_logits(features['logits']))
    del features['logits']  # save memory

    # Publish the cache only now, so a failure above cannot leave a
    # half-populated cache (e.g. logits kept and score still None).
    self.cached_inception_real = features
    self.real_inception_score = score
    print('real inception score', self.real_inception_score)

    return self.cached_inception_real, self.real_inception_score
253
254
255class EvalWorker:

Callers: 1

__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected