MCPcopy
hub / github.com/hojonathanho/diffusion / _make_sampling_graph

Method _make_sampling_graph

diffusion_tf/tpu_utils/tpu_utils.py:310–331  ·  view source on GitHub ↗
(self, img_shape, with_inception)

Source from the content-addressed store, hash-verified

308 img_shape=img_batch_shape[1:], with_inception=True)
309
def _make_sampling_graph(self, img_shape, with_inception):
    """Build the distributed sampling graph, optionally with Inception activations.

    Args:
      img_shape: Per-example image shape, i.e. everything after the batch
        axis. It becomes the trailing dims of the dummy ``input_x`` tensor.
      with_inception: If True, also build a second graph that runs
        ``self.model.sample_and_run_inception`` at the Inception batch size.

    Returns:
      ``samples_outputs`` when ``with_inception`` is False. Otherwise the
      tuple ``(samples_outputs, samples_inception)``. Each element is whatever
      ``distributed(..., reduction='concat', strategy=self.strategy)`` returns
      for the corresponding model function.
    """

    def _make_inputs(local_bs):
        """Return dummy ``(input_x, input_y)`` tensors with leading dim ``local_bs``.

        ``input_x`` has shape ``[local_bs, *img_shape]`` and is filled with NaN.
        ``input_y`` holds ``local_bs`` random int32 labels drawn from
        ``[0, self.dataset.num_classes)``.
        """
        # Dummy inputs to feed to samplers.
        # NOTE(review): the NaN fill suggests samplers use only the shape of
        # input_x, never its values (any accidental read would propagate
        # NaN). Confirm against the model's samples_fn.
        input_x = tf.fill([local_bs, *img_shape], value=np.nan)
        input_y = tf.random_uniform(
            [local_bs], 0, self.dataset.num_classes, dtype=tf.int32)
        return input_x, input_y

    # Samples.
    # NOTE(review): local_bs is presumably the per-replica batch size, with
    # distributed(reduction='concat') joining replica outputs. TODO confirm.
    samples_outputs = distributed(
        self.model.samples_fn,
        args=_make_inputs(self.local_bs),
        reduction='concat', strategy=self.strategy)
    if not with_inception:
        return samples_outputs

    # Inception activations of samples, built with their own batch size.
    samples_inception = distributed(
        self.model.sample_and_run_inception,
        args=_make_inputs(self.inception_local_bs),
        reduction='concat', strategy=self.strategy)
    return samples_outputs, samples_inception
332
333 def _run_sampling(self, sess, ema: bool):
334 out = {}

Callers 1

__init__ (method) · 0.95

Calls 1

distributed (function) · 0.85

Tested by

no test coverage detected