Function gauss_kl

gpflow/kullback_leiblers.py:59–165 (github.com/GPflow/GPflow)

Compute the KL divergence KL[q || p] between q(x) = N(q_mu, q_sqrt^2) and p(x) = N(0, K), or p(x) = N(0, I) when no K matrix is given (neither K nor K_cholesky). The L independent distributions are given by the columns of q_mu and the first or last dimension of q_sqrt, and the sum of their divergences is returned.
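
For reference, the standard closed form of this divergence, summed over the L columns, is written out below in our own notation (not taken from the GPflow source). Here μ_l is column l of q_mu, R_l is the l-th lower-triangular factor in q_sqrt (a diagonal matrix built from column l when q_sqrt is [M, L]), K_l is the prior covariance for column l (the shared K when a single [M, M] matrix is broadcast, I when no K is given), and L_p is its Cholesky factor. The second and third lines show how the Cholesky factor avoids forming K_l^{-1}; the triangular solve that produces `alpha` in the source is exactly L_p^{-1} μ_l.

```latex
\begin{aligned}
\mathrm{KL}[q \,\|\, p] &= \sum_{l=1}^{L} \tfrac{1}{2}\Big[\operatorname{tr}\big(K_l^{-1} S_l\big) + \mu_l^{\top} K_l^{-1} \mu_l - M + \log\lvert K_l\rvert - \log\lvert S_l\rvert\Big], \qquad S_l = R_l R_l^{\top},\\
\mu_l^{\top} K_l^{-1} \mu_l &= \big\lVert L_p^{-1}\mu_l \big\rVert_2^2, \qquad
\operatorname{tr}\big(K_l^{-1} S_l\big) = \big\lVert L_p^{-1} R_l \big\rVert_F^2,\\
\log\lvert K_l\rvert &= 2\sum_{i=1}^{M}\log\,(L_p)_{ii}, \qquad
\log\lvert S_l\rvert = 2\sum_{i=1}^{M}\log\,\lvert (R_l)_{ii}\rvert .
\end{aligned}
```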

gauss_kl(q_mu: TensorType, q_sqrt: TensorType, K: TensorType = None, *, K_cholesky: TensorType = None) -> tf.Tensor

Source

# ... (earlier lines of the file not shown)
    "return: []",
)
def gauss_kl(
    q_mu: TensorType, q_sqrt: TensorType, K: TensorType = None, *, K_cholesky: TensorType = None
) -> tf.Tensor:
    """
    Compute the KL divergence KL[q || p] between::

          q(x) = N(q_mu, q_sqrt^2)

    and::

          p(x) = N(0, K) if K is not None
          p(x) = N(0, I) if K is None

    We assume L multiple independent distributions, given by the columns of
    q_mu and the first or last dimension of q_sqrt. Returns the *sum* of the
    divergences.

    q_mu is a matrix ([M, L]), each column contains a mean.

    - q_sqrt can be a 3D tensor ([L, M, M]), each matrix within is a lower
      triangular square-root matrix of the covariance of q.
    - q_sqrt can be a matrix ([M, L]), each column represents the diagonal of a
      square-root matrix of the covariance of q.

    K is the covariance of p (positive-definite matrix). The K matrix can be
    passed either directly as `K`, or as its Cholesky factor, `K_cholesky`. In
    either case, it can be a single matrix [M, M], in which case the sum of the
    L KL divergences is computed by broadcasting, or L different covariances
    [L, M, M].

    Note: if no K matrix is given (both `K` and `K_cholesky` are None),
    `gauss_kl` computes the KL divergence from p(x) = N(0, I) instead.
    """

    if (K is not None) and (K_cholesky is not None):
        raise ValueError(
            "Ambiguous arguments: gauss_kl() must only be passed one of `K` or `K_cholesky`."
        )

    is_white = (K is None) and (K_cholesky is None)
    is_diag = len(q_sqrt.shape) == 2

    M, L = tf.shape(q_mu)[0], tf.shape(q_mu)[1]

    if is_white:
        alpha = q_mu  # [M, L]
    else:
        if K is not None:
            Lp = tf.linalg.cholesky(K)  # [L, M, M] or [M, M]
        elif K_cholesky is not None:
            Lp = K_cholesky  # [L, M, M] or [M, M]

        is_batched = len(Lp.shape) == 3

        q_mu = tf.transpose(q_mu)[:, :, None] if is_batched else q_mu  # [L, M, 1] or [M, L]
        alpha = tf.linalg.triangular_solve(Lp, q_mu, lower=True)  # [L, M, 1] or [M, L]

    if is_diag:
        ...  # remainder of the function (through line 165 of the file) not shown
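
The listing breaks off at `if is_diag:`, so the remaining steps of GPflow's implementation are not shown here. As a cross-check, here is a minimal NumPy sketch, assuming NumPy is available, that accepts the same argument layouts and evaluates the textbook closed form. `gauss_kl_reference` is a hypothetical name, not part of GPflow, and the sketch uses a general `np.linalg.solve` where the listing uses `tf.linalg.triangular_solve`.

```python
import numpy as np


def gauss_kl_reference(q_mu, q_sqrt, K=None, *, K_cholesky=None):
    """Sum over columns l of KL[N(q_mu[:, l], S_l) || N(0, K_l)].

    Hypothetical NumPy helper (not part of GPflow) using the same argument
    conventions as gauss_kl: q_mu is [M, L]; q_sqrt is [M, L] (diagonal
    square roots) or [L, M, M] (lower-triangular square roots); K or
    K_cholesky is [M, M] or [L, M, M]; with neither given, p = N(0, I).
    """
    if K is not None and K_cholesky is not None:
        raise ValueError("pass only one of K or K_cholesky")
    q_mu = np.asarray(q_mu, dtype=float)
    q_sqrt = np.asarray(q_sqrt, dtype=float)
    M, L = q_mu.shape

    # Square-root factors R_l as an [L, M, M] stack (diagonal case expanded).
    if q_sqrt.ndim == 2:
        R = np.stack([np.diag(q_sqrt[:, l]) for l in range(L)])
    else:
        R = q_sqrt

    # Cholesky factor of the prior covariance, broadcast to [L, M, M].
    if K is not None:
        Lp = np.linalg.cholesky(np.asarray(K, dtype=float))
    elif K_cholesky is not None:
        Lp = np.asarray(K_cholesky, dtype=float)
    else:
        Lp = np.eye(M)  # whitened case: p = N(0, I)
    Lp = np.broadcast_to(Lp, (L, M, M))

    # alpha = Lp^{-1} mu_l and B = Lp^{-1} R_l; a general solve is used here
    # for brevity, although a triangular solve would exploit Lp's structure.
    alpha = np.linalg.solve(Lp, q_mu.T[:, :, None])  # [L, M, 1]
    B = np.linalg.solve(Lp, R)  # [L, M, M]

    mahalanobis = np.sum(alpha**2)  # sum_l mu_l^T K_l^{-1} mu_l
    trace = np.sum(B**2)  # sum_l tr(K_l^{-1} S_l)
    logdet_p = 2.0 * np.sum(np.log(np.abs(np.diagonal(Lp, axis1=1, axis2=2))))
    logdet_q = 2.0 * np.sum(np.log(np.abs(np.diagonal(R, axis1=1, axis2=2))))
    return 0.5 * (mahalanobis + trace - M * L + logdet_p - logdet_q)
```

For example, with `q_mu = [[1.0]]`, `q_sqrt = [[1.0]]` and no prior covariance (M = L = 1), the sketch evaluates 0.5 * (1 + 1 - 1 + 0 - 0) = 0.5. In the white case `Lp` is the identity, so `alpha` is just `q_mu`, mirroring the `is_white` branch above. Expanding a diagonal `q_sqrt` into full matrices keeps the sketch short at the cost of O(L M²) memory.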

Callers (9; 7 listed)

test_kl_k_cholesky (function) · 0.90
test_diags (function) · 0.90
test_whitened (function) · 0.90
test_oned (function) · 0.90
test_unknown_size_inputs (function) · 0.90
_ (function) · 0.85
elbo (method) · 0.85

Calls 3

to_default_float (function) · 0.85
default_float (function) · 0.85
shape (method) · 0.45

Tested by (7; 5 listed)

test_kl_k_cholesky (function) · 0.72
test_diags (function) · 0.72
test_whitened (function) · 0.72
test_oned (function) · 0.72
test_unknown_size_inputs (function) · 0.72
