Repository: github.com/Vchitect/Latte

Method condition_score

diffusion/gaussian_diffusion.py:362–378

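Compute what the p_mean_variance output would have been if the model's score function had been conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), which adjusts the predicted mean directly, this method uses the conditioning strategy from Song et al. (2020): it recovers the noise prediction eps from p_mean_var["pred_xstart"], subtracts sqrt(1 - alpha_bar) times the gradient returned by cond_fn(x, t, **model_kwargs), and recomputes pred_xstart and the posterior mean from the adjusted eps. The remaining entries of p_mean_var (e.g. the variance terms) are copied through unchanged. Because model_kwargs is unpacked with **, callers must pass a dict; the default of None would raise a TypeError.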
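The eps update follows from the usual relation between a noise prediction and the score of the noised data. A short derivation, assuming the standard DDPM forward process q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I) and that cond_fn returns a (possibly scaled) gradient of log p(y | x_t). The second line is Bayes' rule, since log p(y) does not depend on x_t:

$$
\begin{aligned}
\nabla_{x_t}\log p_t(x_t) &\approx -\frac{\epsilon_\theta(x_t,t)}{\sqrt{1-\bar\alpha_t}} \\
\nabla_{x_t}\log p_t(x_t \mid y) &= \nabla_{x_t}\log p_t(x_t) + \nabla_{x_t}\log p_t(y \mid x_t) \\
\hat\epsilon(x_t,t) &= -\sqrt{1-\bar\alpha_t}\,\nabla_{x_t}\log p_t(x_t \mid y)
  = \epsilon_\theta(x_t,t) - \sqrt{1-\bar\alpha_t}\,\nabla_{x_t}\log p_t(y \mid x_t) \\
\hat x_0 &= \frac{x_t - \sqrt{1-\bar\alpha_t}\,\hat\epsilon}{\sqrt{\bar\alpha_t}} \\
\tilde\mu_t &= \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,\hat x_0
  + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t
\end{aligned}
$$

Assuming _predict_xstart_from_eps and q_posterior_mean_variance implement the standard DDPM formulas, the last two lines correspond to the new "pred_xstart" and "mean" entries.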

(self, cond_fn, p_mean_var, x, t, model_kwargs=None)
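cond_fn is called as cond_fn(x, t, **model_kwargs), and its result is scaled by sqrt(1 - alpha_bar) and subtracted from eps, so it must return an array shaped like x that represents a log-likelihood gradient with respect to x. The sketch below is a hypothetical cond_fn with an analytic gradient, assuming a Gaussian likelihood p(y | x) = N(y; x, sigma^2 I); the names gaussian_cond_fn, y, and sigma are illustrative and not part of Latte. In classifier guidance this gradient would instead come from autograd through a classifier.

```python
import numpy as np


def gaussian_cond_fn(x, t, y=None, sigma=1.0, **unused):
    """Hypothetical cond_fn: gradient of log N(y; x, sigma^2 I) w.r.t. x.

    Follows the call pattern cond_fn(x, t, **model_kwargs). Extra keys in
    model_kwargs (which samplers may also pass to the model) are ignored,
    and t is unused by this toy likelihood.
    """
    if y is None:
        # No conditioning target: contribute a zero gradient.
        return np.zeros_like(x)
    # d/dx [-||y - x||^2 / (2 sigma^2)] = (y - x) / sigma^2
    return (y - x) / sigma**2


x = np.zeros((2, 3))
t = np.array([10, 10])
model_kwargs = {"y": np.ones((2, 3)), "sigma": 0.5}
grad = gaussian_cond_fn(x, t, **model_kwargs)
print(grad.shape)  # (2, 3)
```

Accepting **unused keeps the function tolerant of whatever other keys the caller's model_kwargs carries.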


```python
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute what the p_mean_variance output would have been, should the
    model's score function be conditioned by cond_fn.
    See condition_mean() for details on cond_fn.
    Unlike condition_mean(), this instead uses the conditioning strategy
    from Song et al (2020).
    """
    # Cumulative product of alphas at timestep t, broadcast to x's shape.
    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

    # Recover the noise prediction implied by the model's x0 estimate,
    # then shift it by the scaled gradient returned by cond_fn.
    eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
    eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)

    # Rebuild pred_xstart and the posterior mean from the conditioned eps;
    # other entries (e.g. variances) are copied unchanged.
    out = p_mean_var.copy()
    out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
    out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
    return out
```
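To check the arithmetic outside the Latte codebase, the sketch below re-implements the same steps with NumPy on a batch of flat vectors. It assumes the standard DDPM formulas for _predict_eps_from_xstart, _predict_xstart_from_eps and q_posterior_mean_variance, and a linear beta schedule; make_linear_schedule, extract and condition_score_sketch are hypothetical names, and Latte's actual schedule may differ.

```python
import numpy as np


def make_linear_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    # Hypothetical linear beta schedule; returns betas and alpha_bar (cumprod).
    betas = np.linspace(beta_start, beta_end, num_steps, dtype=np.float64)
    return betas, np.cumprod(1.0 - betas)


def extract(arr, t, shape):
    # Gather arr[t] per batch element, then append singleton axes so the
    # result broadcasts against an array of the given shape.
    out = arr[t]
    return out.reshape(out.shape + (1,) * (len(shape) - out.ndim))


def condition_score_sketch(betas, alphas_cumprod, cond_fn, p_mean_var, x, t, model_kwargs=None):
    model_kwargs = {} if model_kwargs is None else model_kwargs
    ab = extract(alphas_cumprod, t, x.shape)
    ab_prev = extract(np.append(1.0, alphas_cumprod[:-1]), t, x.shape)
    beta_t = extract(betas, t, x.shape)

    # Noise implied by the x0 prediction (invert x_t = sqrt(ab) x0 + sqrt(1-ab) eps).
    eps = (x - np.sqrt(ab) * p_mean_var["pred_xstart"]) / np.sqrt(1.0 - ab)
    # Score-based conditioning: shift eps by the scaled log-likelihood gradient.
    eps = eps - np.sqrt(1.0 - ab) * cond_fn(x, t, **model_kwargs)

    out = dict(p_mean_var)  # shallow copy; other entries pass through
    out["pred_xstart"] = (x - np.sqrt(1.0 - ab) * eps) / np.sqrt(ab)
    # Posterior mean of q(x_{t-1} | x_t, x_0) at the conditioned x0.
    coef_x0 = beta_t * np.sqrt(ab_prev) / (1.0 - ab)
    coef_xt = (1.0 - ab_prev) * np.sqrt(1.0 - beta_t) / (1.0 - ab)
    out["mean"] = coef_x0 * out["pred_xstart"] + coef_xt * x
    return out


betas, alphas_cumprod = make_linear_schedule()
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4))
t = np.array([500, 10])
p_mean_var = {"pred_xstart": rng.standard_normal((2, 4)), "variance": None}
g = np.ones_like(x)  # constant toy gradient from cond_fn
out = condition_score_sketch(betas, alphas_cumprod, lambda x, t, **kw: g, p_mean_var, x, t)
```

With a cond_fn that returns zeros, pred_xstart comes back unchanged. A nonzero gradient g moves pred_xstart by (1 - alpha_bar) / sqrt(alpha_bar) * g, so the guidance term has a larger effect on the x0 estimate at noisy timesteps where alpha_bar is small.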

Callers (2)

ddim_sample · method · 0.95
ddim_reverse_sample · method · 0.95

Calls (4)

_extract_into_tensor · function · 0.85

Tested by

No test coverage detected.