| 308 | img_shape=img_batch_shape[1:], with_inception=True) |
| 309 | |
def _make_sampling_graph(self, img_shape, with_inception):
    """Build the distributed sampling ops, optionally with Inception activations.

    Args:
        img_shape: Per-example image shape (no batch dimension). It is only used
            to size the dummy image tensor fed to the samplers.
        with_inception: If True, also build the ops produced by
            ``self.model.sample_and_run_inception``.

    Returns:
        ``samples_outputs`` when ``with_inception`` is False, otherwise the
        tuple ``(samples_outputs, samples_inception)``. Both values are the
        results of ``distributed(..., reduction='concat',
        strategy=self.strategy)``.
    """

    def _make_inputs(local_bs):
        """Return dummy ``(input_x, input_y)`` tensors for one replica's batch.

        ``input_x`` is filled with NaN. It appears to exist only to give the
        samplers an input of the right shape, and its values are presumably
        never read (TODO confirm against ``samples_fn``).
        ``input_y`` holds int32 labels drawn uniformly from
        ``[0, self.dataset.num_classes)``; the upper bound is exclusive for
        integer dtypes.
        """
        input_x = tf.fill([local_bs, *img_shape], value=np.nan)
        input_y = tf.random_uniform(
            [local_bs], 0, self.dataset.num_classes, dtype=tf.int32)
        return input_x, input_y

    # Samples. Only the per-replica batch size is needed to shape the dummy
    # inputs handed to each replica.
    samples_outputs = distributed(
        self.model.samples_fn,
        args=_make_inputs(self.local_bs),
        reduction='concat', strategy=self.strategy)
    if not with_inception:
        return samples_outputs

    # Inception activations of samples. This path uses its own per-replica
    # batch size (self.inception_local_bs), independent of self.local_bs.
    samples_inception = distributed(
        self.model.sample_and_run_inception,
        args=_make_inputs(self.inception_local_bs),
        reduction='concat', strategy=self.strategy)
    return samples_outputs, samples_inception
| 332 | |
| 333 | def _run_sampling(self, sess, ema: bool): |
| 334 | out = {} |