Compute the log-likelihood of a Gaussian distribution discretizing to a given image. :param x: the target images, assumed to be uint8 values rescaled to the range [-1, 1]. :param means: the Gaussian mean Tensor. :param log_scales: the Gaussian log stddev Tensor. :return: a tensor like x of log probabilities (in nats).
(x, *, means, log_scales)
| 48 | |
| 49 | |
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.

    Each value of ``x`` is treated as the centre of a bin of half-width
    1/255, which is half the 2/255 spacing between uint8 levels rescaled to
    [-1, 1]. A bin's probability is the Gaussian mass between its edges. The
    lowest bin (x near -1) instead takes all mass below its upper edge, and
    the highest bin (x near +1) takes all mass above its lower edge.

    :param x: the target images. It is assumed that these were uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    :raises ValueError: if ``x``, ``means`` and ``log_scales`` do not all
                        have the same shape.
    """
    # Validate explicitly rather than with ``assert``: asserts are stripped
    # under ``python -O``, after which mismatched shapes would silently
    # broadcast into a result that is not shaped like ``x``.
    if not (x.shape == means.shape == log_scales.shape):
        raise ValueError(
            "x, means and log_scales must have the same shape; got "
            f"{tuple(x.shape)}, {tuple(means.shape)}, {tuple(log_scales.shape)}"
        )
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    # Standardized upper and lower bin edges (bin half-width is 1/255).
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    # Clamp before taking the log so a zero probability yields a large but
    # finite negative value instead of -inf.
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    # Lowest bin: mass from -inf up to the upper edge.
    # Highest bin: mass from the lower edge up to +inf.
    # Interior bins: mass between the two edges.
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    # Internal sanity check on the output shape (not input validation).
    assert log_probs.shape == x.shape
    return log_probs
no test coverage detected