Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al. (2020).
(self, cond_fn, p_mean_var, x, t, model_kwargs=None)
| 360 | return new_mean |
| 361 | |
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute what the p_mean_variance output would have been, should the
    model's score function be conditioned by cond_fn.

    See condition_mean() for details on cond_fn. Unlike condition_mean(),
    this instead uses the conditioning strategy from Song et al. (2020):
    the noise prediction implied by p_mean_var["pred_xstart"] is shifted by
    the conditioning term,

        eps' = eps - sqrt(1 - alpha_bar_t) * cond_fn(x, t, **model_kwargs),

    and "pred_xstart" and "mean" are then recomputed from eps'.

    :param cond_fn: callable invoked as cond_fn(x, t, **model_kwargs).
        Presumably it returns the gradient of log p(y | x_t) with respect to
        x (see condition_mean()); its result must be broadcast-compatible
        with the eps tensor. TODO confirm against condition_mean().
    :param p_mean_var: mapping produced by p_mean_variance() (presumably a
        dict); must contain "pred_xstart". It is not modified.
    :param x: the current sample x_t (passed as x_t to the posterior).
    :param t: the timestep(s) corresponding to x.
    :param model_kwargs: optional dict of extra keyword arguments forwarded
        to cond_fn. None means no extra arguments.
    :return: a shallow copy of p_mean_var in which "pred_xstart" and "mean"
        are replaced by their conditioned values; every other entry is
        carried over unchanged.
    """
    if model_kwargs is None:
        # `cond_fn(..., **None)` would raise TypeError, so honour the
        # advertised default by treating it as "no extra kwargs".
        model_kwargs = {}

    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

    # Recover the noise estimate implied by the unconditioned x_0 prediction,
    # then shift it by the scaled conditioning term.
    eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
    eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)

    # Copy so the caller's dict is left untouched; only the two derived
    # entries are overwritten below.
    out = p_mean_var.copy()
    out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
    # q_posterior_mean_variance returns a 3-tuple; only the mean is needed.
    out["mean"], _, _ = self.q_posterior_mean_variance(
        x_start=out["pred_xstart"], x_t=x, t=t
    )
    return out
| 379 | |
| 380 | def p_sample( |
| 381 | self, |
no test coverage detected