Compute the KL divergence KL[q || p] between::

    q(x) = N(q_mu, q_sqrt^2)

and::

    p(x) = N(0, K)    if K is not None
    p(x) = N(0, I)    if K is None

We assume L multiple independent distributions, given by the columns of q_mu and the first or last dimens
(
q_mu: TensorType, q_sqrt: TensorType, K: TensorType = None, *, K_cholesky: TensorType = None
)
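
As a minimal sketch of the simplest case (no `K` given, so p(x) = N(0, I), and a diagonal `q_sqrt` laid out as [M, L] like `q_mu`), the sum of the divergences has an element-wise closed form: 0.5 * (mu^2 + s^2 - 1 - log s^2). The helper name `gauss_kl_white_diag` and the nested-list inputs are assumptions made for illustration in plain Python; this is not the library's implementation, which operates on TensorFlow tensors.

```python
import math


def gauss_kl_white_diag(q_mu, q_sqrt):
    """Sum of KL[N(mu, diag(s^2)) || N(0, I)] over all L columns.

    q_mu and q_sqrt are [M, L] nested lists. Hypothetical helper for
    illustration only; not part of any library API.
    """
    total = 0.0
    for mu_row, s_row in zip(q_mu, q_sqrt):
        for mu, s in zip(mu_row, s_row):
            # Per-element KL between N(mu, s^2) and N(0, 1).
            total += 0.5 * (mu * mu + s * s - 1.0 - math.log(s * s))
    return total


# M = 2, L = 2: column 0 equals the prior, column 1 is shifted and narrower.
q_mu = [[0.0, 1.0],
        [0.0, 0.0]]
q_sqrt = [[1.0, 0.5],
          [1.0, 1.0]]
kl = gauss_kl_white_diag(q_mu, q_sqrt)
```

Column 0 matches the prior exactly and contributes nothing, so the whole value comes from the single element with mean 1 and standard deviation 0.5, which gives 0.125 + log 2.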
    "return: []",
)
def gauss_kl(
    q_mu: TensorType, q_sqrt: TensorType, K: TensorType = None, *, K_cholesky: TensorType = None
) -> tf.Tensor:
    """
    Compute the KL divergence KL[q || p] between::

        q(x) = N(q_mu, q_sqrt^2)

    and::

        p(x) = N(0, K)    if K is not None
        p(x) = N(0, I)    if K is None

    We assume L multiple independent distributions, given by the columns of
    q_mu and the first or last dimension of q_sqrt. Returns the *sum* of the
    divergences.

    q_mu is a matrix ([M, L]), each column contains a mean.

    - q_sqrt can be a 3D tensor ([L, M, M]), each matrix within is a lower
      triangular square-root matrix of the covariance of q.
    - q_sqrt can be a matrix ([M, L]), each column represents the diagonal of a
      square-root matrix of the covariance of q.

    K is the covariance of p (positive-definite matrix). The K matrix can be
    passed either directly as `K`, or as its Cholesky factor, `K_cholesky`. In
    either case, it can be a single matrix [M, M], in which case the sum of the
    L KL divergences is computed by broadcasting, or L different covariances
    [L, M, M].

    Note: if no K matrix is given (both `K` and `K_cholesky` are None),
    `gauss_kl` computes the KL divergence from p(x) = N(0, I) instead.
    """

    if (K is not None) and (K_cholesky is not None):
        raise ValueError(
            "Ambiguous arguments: gauss_kl() must only be passed one of `K` or `K_cholesky`."
        )

    is_white = (K is None) and (K_cholesky is None)
    is_diag = len(q_sqrt.shape) == 2

    M, L = tf.shape(q_mu)[0], tf.shape(q_mu)[1]

    if is_white:
        alpha = q_mu  # [M, L]
    else:
        if K is not None:
            Lp = tf.linalg.cholesky(K)  # [L, M, M] or [M, M]
        elif K_cholesky is not None:
            Lp = K_cholesky  # [L, M, M] or [M, M]

        is_batched = len(Lp.shape) == 3

        q_mu = tf.transpose(q_mu)[:, :, None] if is_batched else q_mu  # [L, M, 1] or [M, L]
        alpha = tf.linalg.triangular_solve(Lp, q_mu, lower=True)  # [L, M, 1] or [M, L]

    if is_diag:
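
The non-white branch in the listing whitens the mean with a triangular solve against the Cholesky factor of K, so the squared norm of `alpha` equals mu^T K^{-1} mu. The sketch below shows that step for a single distribution with one [M, M] covariance, in plain Python. The trace and log-determinant terms are the standard closed-form KL between two Gaussians; the listing is cut off before it reaches them, so treating them as what follows is an assumption. The names `cholesky`, `forward_solve` and `gauss_kl_single` are hypothetical helpers, not library functions.

```python
import math


def cholesky(K):
    """Lower-triangular Lp with Lp @ Lp^T == K (K symmetric positive definite)."""
    n = len(K)
    Lp = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(Lp[i][k] * Lp[j][k] for k in range(j))
            if i == j:
                Lp[i][j] = math.sqrt(K[i][i] - s)
            else:
                Lp[i][j] = (K[i][j] - s) / Lp[j][j]
    return Lp


def forward_solve(Lp, b):
    """Solve Lp @ x = b for a vector b by forward substitution."""
    x = []
    for i, row in enumerate(Lp):
        x.append((b[i] - sum(row[k] * x[k] for k in range(i))) / row[i])
    return x


def gauss_kl_single(mu, S, K):
    """KL[N(mu, S S^T) || N(0, K)] for one distribution; S is lower triangular."""
    M = len(mu)
    Lp = cholesky(K)
    alpha = forward_solve(Lp, mu)  # whitened mean, Lp^{-1} mu
    mahalanobis = sum(a * a for a in alpha)  # mu^T K^{-1} mu
    # trace(K^{-1} S S^T) is the squared Frobenius norm of Lp^{-1} S,
    # accumulated here one column of S at a time.
    trace = 0.0
    for j in range(M):
        col = forward_solve(Lp, [S[i][j] for i in range(M)])
        trace += sum(c * c for c in col)
    logdet_p = 2.0 * sum(math.log(Lp[i][i]) for i in range(M))
    logdet_q = 2.0 * sum(math.log(abs(S[i][i])) for i in range(M))
    return 0.5 * (mahalanobis + trace - M + logdet_p - logdet_q)


K = [[4.0, 2.0],
     [2.0, 2.0]]
# q equal to p (zero mean, square root equal to the Cholesky factor of K).
kl_same = gauss_kl_single([0.0, 0.0], cholesky(K), K)
# Identity prior and identity square root: only the mean term remains.
eye = [[1.0, 0.0], [0.0, 1.0]]
kl_shifted = gauss_kl_single([1.0, 2.0], eye, eye)
```

Working with the Cholesky factor avoids forming K^{-1} explicitly, and the log-determinant of K falls out of the diagonal of `Lp`.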