(self, denoise_fn, x_start, x_t, t, *, clip_denoised: bool, return_pred_xstart: bool)
| 255 | # === Log likelihood calculation === |
| 256 | |
| 257 | def _vb_terms_bpd(self, denoise_fn, x_start, x_t, t, *, clip_denoised: bool, return_pred_xstart: bool):  # Per-example variational-bound term in bits: decoder NLL where t == 0, KL(q || p) elsewhere. |
| 258 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)  # parameters of q(x_{t-1}|x_t,x_0); the plain variance is unused |
| 259 | model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(  # model's parameters of p(x_{t-1}|x_t) |
| 260 | denoise_fn, x=x_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)  # always request pred_xstart so it can be returned below if asked |
| 261 | kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)  # KL(q || p) from (mean, log-variance) pairs; presumably elementwise — confirm |
| 262 | kl = nn.meanflat(kl) / np.log(2.)  # reduce to one value per example; dividing by ln 2 converts nats to bits (assuming normal_kl uses natural log) |
| 263 | |
| 264 | decoder_nll = -utils.discretized_gaussian_log_likelihood(  # negative log-likelihood of x_start under the model's discretized Gaussian |
| 265 | x_start, means=model_mean, log_scales=0.5 * model_log_variance)  # log std = 0.5 * log variance |
| 266 | assert decoder_nll.shape == x_start.shape  # still per-element before the reduction below |
| 267 | decoder_nll = nn.meanflat(decoder_nll) / np.log(2.)  # per-example, same units (bits) as kl |
| 268 | |
| 269 | # At the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) |
| 270 | assert kl.shape == decoder_nll.shape == t.shape == [x_start.shape[0]]  # static check: one entry per batch element; note asserts are stripped under python -O |
| 271 | output = tf.where(tf.equal(t, 0), decoder_nll, kl)  # per-example select: decoder_nll where t == 0, kl otherwise |
| 272 | return (output, pred_xstart) if return_pred_xstart else output  # pred_xstart is returned only on request |
| 273 | |
| 274 | def training_losses(self, denoise_fn, x_start, t, noise=None): |
| 275 | """ |
no test coverage detected