MCPcopy Index your code
hub / github.com/hojonathanho/diffusion / _vb_terms_bpd

Method _vb_terms_bpd

diffusion_tf/diffusion_utils_2.py:257–272  ·  view source on GitHub ↗
(self, denoise_fn, x_start, x_t, t, *, clip_denoised: bool, return_pred_xstart: bool)

Source from the content-addressed store, hash-verified

255 # === Log likelihood calculation ===
256
def _vb_terms_bpd(self, denoise_fn, x_start, x_t, t, *, clip_denoised: bool, return_pred_xstart: bool):
  """Per-example variational-bound term for timestep ``t``, in bits.

  Each batch element gets one of two quantities. Where ``t == 0`` it gets the
  decoder negative log-likelihood of ``x_start``, using a discretized Gaussian
  with the model's mean and log-scale ``0.5 * model_log_variance``. Everywhere
  else it gets KL(q(x_{t-1} | x_t, x_0) || p(x_{t-1} | x_t)). Both quantities are
  reduced to shape ``[batch]`` by ``nn.meanflat`` and divided by ln(2), which
  converts nats to bits.

  Args:
    denoise_fn: model callable, forwarded to ``self.p_mean_variance``.
    x_start: the clean input x_0.
    x_t: the sample at timestep ``t``.
    t: per-example timestep tensor; asserted to have shape ``[x_start.shape[0]]``.
    clip_denoised: forwarded to ``self.p_mean_variance``.
    return_pred_xstart: when True, also return the model's predicted x_0.

  Returns:
    A ``[batch]`` tensor of bits. When ``return_pred_xstart`` is True, returns
    the tuple ``(that tensor, pred_xstart)`` instead.
  """
  nats_per_bit = np.log(2.)

  # KL term: true posterior q(x_{t-1} | x_t, x_0) vs. the model's p(x_{t-1} | x_t).
  post_mean, _, post_log_var = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
  pred_mean, _, pred_log_var, pred_xstart = self.p_mean_variance(
    denoise_fn, x=x_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
  kl_bits = nn.meanflat(normal_kl(post_mean, post_log_var, pred_mean, pred_log_var)) / nats_per_bit

  # Decoder term: 0.5 * log-variance is the log standard deviation the likelihood expects.
  nll = -utils.discretized_gaussian_log_likelihood(
    x_start, means=pred_mean, log_scales=0.5 * pred_log_var)
  assert nll.shape == x_start.shape
  nll_bits = nn.meanflat(nll) / nats_per_bit

  # Both terms are computed for every example; t then selects per example which one is kept.
  batch_shape = [x_start.shape[0]]
  assert kl_bits.shape == nll_bits.shape == t.shape == batch_shape
  bits = tf.where(tf.equal(t, 0), nll_bits, kl_bits)
  if return_pred_xstart:
    return bits, pred_xstart
  return bits
273
274 def training_losses(self, denoise_fn, x_start, t, noise=None):
275 """

Callers 2

training_losses — Method · 0.95
_loop_body — Method · 0.95

Calls 3

p_mean_variance — Method · 0.95
normal_kl — Function · 0.70

Tested by

no test coverage detected