Elementwise KL divergence KL(N(mean1, exp(logvar1)) ‖ N(mean2, exp(logvar2))) between normal distributions parameterized by mean and log-variance.
(mean1, logvar1, mean2, logvar2)
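For reference, this is the standard closed-form KL divergence between two univariate Gaussians. Writing the variances as σ_i² = exp(logvar_i) and abbreviating ℓ_i = logvar_i, substituting the log-variances gives the expression the function evaluates, term by term:

```latex
\begin{aligned}
D_{\mathrm{KL}}\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big)
&= \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2} \\
&= \tfrac{1}{2}\Big(-1 + \ell_2 - \ell_1 + e^{\ell_1-\ell_2} + (\mu_1-\mu_2)^2\, e^{-\ell_2}\Big),
\qquad \ell_i = \log\sigma_i^2 .
\end{aligned}
```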
```python
import tensorflow as tf


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))) between
    normal distributions parameterized by mean and log-variance.

    Computed elementwise; the arguments broadcast against each other.
    """
    # 0.5 * (-1 + log(var2 / var1) + var1 / var2 + (mean1 - mean2)^2 / var2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + tf.exp(logvar1 - logvar2)
                  + tf.math.squared_difference(mean1, mean2) * tf.exp(-logvar2))
```
| 16 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): |