Compute the mean and variance of the diffusion posterior $q(x_{t-1} \mid x_t, x_0)$.
Signature: `q_posterior_mean_variance(self, x_start, x_t, t)`
| 230 | ) |
| 231 | |
| 232 | def q_posterior_mean_variance(self, x_start, x_t, t): |
| 233 | """ |
| 234 | Compute the mean and variance of the diffusion posterior: |
| 235 | q(x_{t-1} | x_t, x_0) |
| 236 | """ |
| 237 | assert x_start.shape == x_t.shape |
| 238 | posterior_mean = ( |
| 239 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start |
| 240 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t |
| 241 | ) |
| 242 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) |
| 243 | posterior_log_variance_clipped = _extract_into_tensor( |
| 244 | self.posterior_log_variance_clipped, t, x_t.shape |
| 245 | ) |
| 246 | assert ( |
| 247 | posterior_mean.shape[0] |
| 248 | == posterior_variance.shape[0] |
| 249 | == posterior_log_variance_clipped.shape[0] |
| 250 | == x_start.shape[0] |
| 251 | ) |
| 252 | return posterior_mean, posterior_variance, posterior_log_variance_clipped |
| 253 | |
| 254 | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): |
| 255 | """ |
no test coverage detected