MCPcopy
hub / github.com/GPflow/GPflow / independent_interdomain_conditional

Function independent_interdomain_conditional

gpflow/conditionals/util.py:257–329  ·  view source on GitHub ↗

The inducing outputs live in the g-space (R^L). Interdomain conditional calculation. Parameters: `full_cov` — calculate the covariance between inputs; `full_output_cov` — calculate the covariance between outputs; `white` — use the whitened representation. Returns: mean, variance.

(
    Kmn: tf.Tensor,
    Kmm: tf.Tensor,
    Knn: tf.Tensor,
    f: tf.Tensor,
    *,
    full_cov: bool = False,
    full_output_cov: bool = False,
    q_sqrt: Optional[tf.Tensor] = None,
    white: bool = False,
)

Source from the content-addressed store, hash-verified

255 "return[1]: [N, P, N, P] if full_cov and full_output_cov",
256)
257def independent_interdomain_conditional(
258 Kmn: tf.Tensor,
259 Kmm: tf.Tensor,
260 Knn: tf.Tensor,
261 f: tf.Tensor,
262 *,
263 full_cov: bool = False,
264 full_output_cov: bool = False,
265 q_sqrt: Optional[tf.Tensor] = None,
266 white: bool = False,
267) -> MeanAndVariance:
268 """
269 The inducing outputs live in the g-space (R^L).
270
271 Interdomain conditional calculation.
272
273 :param full_cov: calculate covariance between inputs
274 :param full_output_cov: calculate covariance between outputs
275 :param white: use whitened representation
276 :return: mean, variance
277 """
278 M, L, N, P = tf.unstack(tf.shape(Kmn), num=Kmn.shape.ndims, axis=0)
279
280 Lm = tf.linalg.cholesky(Kmm) # [L, M, M]
281
282 # Compute the projection matrix A
283 Kmn = tf.reshape(tf.transpose(Kmn, (1, 0, 2, 3)), (L, M, N * P))
284 A = tf.linalg.triangular_solve(Lm, Kmn, lower=True) # [L, M, M] \ [L, M, N*P] -> [L, M, N*P]
285 Ar = tf.reshape(A, (L, M, N, P))
286
287 # compute the covariance due to the conditioning
288 if full_cov and full_output_cov:
289 fvar = Knn - tf.tensordot(Ar, Ar, [[0, 1], [0, 1]]) # [N, P, N, P]
290 elif full_cov and not full_output_cov:
291 At = tf.reshape(tf.transpose(Ar), (P, N, M * L)) # [P, N, L]
292 fvar = Knn - tf.linalg.matmul(At, At, transpose_b=True) # [P, N, N]
293 elif not full_cov and full_output_cov:
294 At = tf.reshape(tf.transpose(Ar, [2, 3, 1, 0]), (N, P, M * L)) # [N, P, L]
295 fvar = Knn - tf.linalg.matmul(At, At, transpose_b=True) # [N, P, P]
296 elif not full_cov and not full_output_cov:
297 fvar = Knn - tf.reshape(tf.reduce_sum(tf.square(A), [0, 1]), (N, P)) # Knn: [N, P]
298
299 # another backsubstitution in the unwhitened case
300 if not white:
301 A = tf.linalg.triangular_solve(
302 Lm, A, adjoint=True
303 ) # [L, M, M] \ [L, M, N*P] -> [L, M, N*P]
304 Ar = tf.reshape(A, (L, M, N, P))
305
306 fmean = tf.tensordot(Ar, f, [[1, 0], [0, 1]]) # [N, P]
307
308 if q_sqrt is not None:
309 if q_sqrt.shape.ndims == 3:
310 Lf = tf.linalg.band_part(q_sqrt, -1, 0) # [L, M, M]
311 LTA = tf.linalg.matmul(
312 Lf, A, transpose_a=True
313 ) # [L, M, M] * [L, M, P] -> [L, M, P]
314 else: # q_sqrt [M, L]

Calls 1

shapeMethod · 0.45

Used in the wild — real call sites across dependent graphs

searching dependent graphs…