Extract the coefficients of `a` at the specified timesteps `t`, then reshape them to [batch_size, 1, 1, 1, 1, ...] so they broadcast against a tensor of shape `x_shape`.
(a, t, x_shape)
| 82 | |
| 83 | @staticmethod |
| 84 | def _extract(a, t, x_shape): |
| 85 | """ |
| 86 | Extract the coefficients of ``a`` at timesteps ``t`` (cast to float32), |
| 87 | then reshape to [batch_size, 1, 1, 1, 1, ...] so they broadcast against a tensor of shape x_shape. |
| 88 | """ |
| 89 | bs, = t.shape  # t must be rank-1 (one timestep per batch element); NOTE(review): needs a static batch size — a None dim would break the reshape below, confirm callers |
| 90 | assert x_shape[0] == bs  # batch dims of t and x must agree; NOTE(review): asserts are stripped under -O, consider raising instead |
| 91 | out = tf.gather(tf.convert_to_tensor(a, dtype=tf.float32), t)  # out[i] = a[t[i]] |
| 92 | assert out.shape == [bs]  # sanity check: exactly one coefficient per timestep |
| 93 | return tf.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))  # e.g. rank-4 x -> [bs, 1, 1, 1] |
| 94 | |
| 95 | def q_mean_variance(self, x_start, t): |
| 96 | mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start |
no outgoing calls into project code (external calls only: `tf.convert_to_tensor`, `tf.gather`, `tf.reshape`)
no test coverage detected