Repository: github.com/hojonathanho/diffusion

Method _extract

diffusion_tf/diffusion_utils_2.py, lines 84–93

Extracts the coefficients in a at the per-example timesteps t, then reshapes the result to [batch_size, 1, 1, ..., 1], with len(x_shape) - 1 trailing ones, so that it broadcasts against a tensor of shape x_shape.
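The gather-then-reshape logic can be sketched in NumPy for readers without TensorFlow at hand. This is a minimal analogue, not the repository's code: the function name extract and the toy coefficient values are illustrative assumptions, and NumPy fancy indexing a[t] stands in for tf.gather.

```python
import numpy as np

def extract(a, t, x_shape):
    """Hypothetical NumPy analogue of _extract: gather a[t] per example,
    then reshape to [batch_size, 1, ..., 1] so it broadcasts against x_shape."""
    t = np.asarray(t)
    (bs,) = t.shape                             # t must be 1-D: one timestep per example
    assert x_shape[0] == bs                     # batch sizes must agree
    out = np.asarray(a, dtype=np.float32)[t]    # gather coefficients at timesteps t
    assert out.shape == (bs,)                   # holds only if a is 1-D
    return out.reshape([bs] + (len(x_shape) - 1) * [1])

# Toy per-timestep coefficients for T = 5 steps (illustrative values only).
a = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
t = np.array([0, 3])                            # batch of 2, each with its own timestep
x = np.ones((2, 4, 4, 3), dtype=np.float32)     # e.g. [batch, height, width, channels]

coef = extract(a, t, x.shape)
print(coef.shape)     # (2, 1, 1, 1)
scaled = coef * x     # example 0 is scaled by a[0], example 1 by a[3]
print(scaled.shape)   # (2, 4, 4, 3)
```

Because the result has shape [batch_size, 1, 1, 1], each example's coefficient scales every pixel and channel of that example only, with no explicit tiling.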

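Signature: _extract(a, t, x_shape), a static method
a: 1-D array of per-timestep coefficients, converted to a float32 tensor
t: 1-D integer tensor of timesteps, one per example, shape [batch_size]
x_shape: shape of the tensor the result will be broadcast against; x_shape[0] must equal batch_size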


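  @staticmethod
  def _extract(a, t, x_shape):
    """
    Extract some coefficients at specified timesteps,
    then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    """
    bs, = t.shape  # t holds one timestep per example: shape [batch_size]
    assert x_shape[0] == bs  # batch dimension of the target must match
    # Look up a[t[i]] for each example i; a must be 1-D so the result is [bs]
    out = tf.gather(tf.convert_to_tensor(a, dtype=tf.float32), t)
    assert out.shape == [bs]
    # Append len(x_shape) - 1 singleton axes so out broadcasts against x_shape
    return tf.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

  def q_mean_variance(self, x_start, t):
    mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
    # ... (listing truncated here; the rest of q_mean_variance is in the source file)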

Callers: 6 (3 shown)

q_mean_variance  (Method · 0.95)
q_sample  (Method · 0.95)
p_mean_variance  (Method · 0.95)
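The callers listed above use the extracted values as per-example scale factors. The NumPy sketch below mirrors the pattern visible in the q_mean_variance line of the listing and follows the standard DDPM forward process q(x_t | x_0) = N(sqrt(ᾱ_t) x_0, (1 - ᾱ_t) I). The linear beta schedule (1e-4 to 0.02 over T = 1000 steps), the helper name extract, and the function bodies are assumptions for illustration, not code copied from diffusion_utils_2.py.

```python
import numpy as np

def extract(a, t, x_shape):
    # Hypothetical helper: gather a[t] per example, reshape to [bs, 1, ..., 1].
    (bs,) = np.shape(t)
    assert x_shape[0] == bs
    return np.asarray(a, dtype=np.float32)[t].reshape([bs] + (len(x_shape) - 1) * [1])

# Assumed schedule for illustration: linear betas from 1e-4 to 0.02 over T = 1000 steps.
T = 1000
betas = np.linspace(1e-4, 0.02, T, dtype=np.float64)
alphas_cumprod = np.cumprod(1.0 - betas)                 # abar_t
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)

def q_mean_variance(x_start, t):
    # Mean and variance of q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I).
    mean = extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start
    variance = extract(1.0 - alphas_cumprod, t, x_start.shape)
    return mean, variance

def q_sample(x_start, t, noise):
    # Draw x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise, one t per example.
    return (extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

rng = np.random.default_rng(0)
x0 = rng.standard_normal((2, 8, 8, 3)).astype(np.float32)
t = np.array([0, T - 1])                                 # first and last timestep
mean, var = q_mean_variance(x0, t)
xt = q_sample(x0, t, rng.standard_normal(x0.shape).astype(np.float32))
print(mean.shape, var.shape, xt.shape)   # (2, 8, 8, 3) (2, 1, 1, 1) (2, 8, 8, 3)
```

The variance stays at shape [batch_size, 1, 1, 1] because it is constant across pixels; it only broadcasts to the full image shape when combined with x_start or noise.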

Calls

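no outgoing calls to other functions in the repository (only the TensorFlow ops tf.convert_to_tensor, tf.gather and tf.reshape)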

Tested by

no test coverage detected