Compute and store Inception features for a dataset
| 207 | |
| 208 | |
class InceptionFeatures:
    """
    Compute and store Inception features for a dataset.

    The Inception activations ('pool_3' and 'logits') for the whole dataset are
    computed lazily on the first call to ``get`` and cached on the instance.
    Later calls return the cached values without running anything in the session.
    """

    def __init__(self, dataset, strategy, limit_dataset_size=0):
        """
        Build the distributed Inception graph over ``dataset``.

        Args:
            dataset: dataset whose elements are indexed with ``['image']``
                (presumably a ``tf.data.Dataset`` of dicts -- TODO confirm).
            strategy: distribution strategy used to distribute the dataset and
                to run the Inception network through ``distributed``.
            limit_dataset_size: if > 0, only the first ``limit_dataset_size``
                elements of ``dataset`` are used; 0 (the default) uses all of it.
        """
        # distributed dataset iterator
        if limit_dataset_size > 0:
            dataset = dataset.take(limit_dataset_size)
        self.ds_iterator = strategy.experimental_distribute_dataset(dataset).make_initializable_iterator()

        # inception network on the dataset; per-replica outputs are combined with reduction='concat'
        self.inception_real = distributed(
            lambda x_: run_inception(tfgan.eval.preprocess_image(x_['image'])),
            args=(next(self.ds_iterator),), reduction='concat', strategy=strategy)

        self.cached_inception_real = None  # cached inception features
        self.real_inception_score = None  # saved inception scores for the dataset

    def get(self, sess):
        """
        Return the dataset's cached Inception features and Inception score.

        On the first call, the dataset iterator is re-initialized and
        ``self.inception_real`` is run until the iterator is exhausted. The
        'pool_3' and 'logits' outputs of all batches are concatenated along
        axis 0 and cast to float64. The Inception score is computed from the
        logits, and the logits are then dropped to save memory. The cache is
        only stored once everything has succeeded, so a failure part-way
        through leaves the instance uncached and the next call retries.

        Args:
            sess: session used to run the iterator initializer and the
                Inception ops.

        Returns:
            A ``(features, score)`` tuple. ``features`` is a dict whose only key
            is 'pool_3' (a float64 array); ``score`` is the Inception score as a
            Python float.

        Raises:
            ValueError: if the dataset yields no batches at all.
        """
        if self.cached_inception_real is None:
            print('computing inception features on the eval set...')
            sess.run(self.ds_iterator.initializer)  # reset the eval dataset iterator
            inception_real_batches, tstart = [], time.time()
            while True:
                try:
                    inception_real_batches.append(sess.run(self.inception_real))
                except tf.errors.OutOfRangeError:
                    break  # iterator exhausted: every batch has been processed

            if not inception_real_batches:
                # np.concatenate would otherwise fail with an opaque message
                raise ValueError('the eval dataset produced no batches; cannot compute inception features')

            features = {
                feat_key: np.concatenate(
                    [batch[feat_key] for batch in inception_real_batches], axis=0
                ).astype(np.float64, copy=False)
                for feat_key in ['pool_3', 'logits']
            }
            del inception_real_batches  # release the per-batch copies before scoring
            print('cached eval inception tensors: logits: {}, pool_3: {} (time: {})'.format(
                features['logits'].shape, features['pool_3'].shape,
                time.time() - tstart))

            score = float(
                classifier_metrics_numpy.classifier_score_from_logits(features['logits']))
            del features['logits']  # save memory

            # Publish the cache only after it is complete, so a failure above
            # cannot leave a half-populated cache that later calls would return.
            self.cached_inception_real = features
            self.real_inception_score = score
            print('real inception score', self.real_inception_score)

        return self.cached_inception_real, self.real_inception_score
| 253 | |
| 254 | |
| 255 | class EvalWorker: |