This function handles conditioning of multi-output GPs in the case where the conditioning points are all fully correlated, in both the prior and posterior. Note: This conditional can handle 'repetitions' R, given in `f` and `q_sqrt`. :param full_cov: calculate covariance between inputs
(
Kmn: tf.Tensor,
Kmm: tf.Tensor,
Knn: tf.Tensor,
f: tf.Tensor,
*,
full_cov: bool = False,
full_output_cov: bool = False,
q_sqrt: Optional[tf.Tensor] = None,
white: bool = False,
)
| 393 | "return[1]: [R, N, P, N, P] if full_cov and full_output_cov", |
| 394 | ) |
| 395 | def fully_correlated_conditional_repeat( |
| 396 | Kmn: tf.Tensor, |
| 397 | Kmm: tf.Tensor, |
| 398 | Knn: tf.Tensor, |
| 399 | f: tf.Tensor, |
| 400 | *, |
| 401 | full_cov: bool = False, |
| 402 | full_output_cov: bool = False, |
| 403 | q_sqrt: Optional[tf.Tensor] = None, |
| 404 | white: bool = False, |
| 405 | ) -> MeanAndVariance: |
| 406 | """ |
| 407 | This function handles conditioning of multi-output GPs in the case where the conditioning |
| 408 | points are all fully correlated, in both the prior and posterior. |
| 409 | Note: This conditional can handle 'repetitions' R, given in `f` and `q_sqrt`. |
| 410 | |
| 411 | :param full_cov: calculate covariance between inputs |
| 412 | :param full_output_cov: calculate covariance between outputs |
| 413 | :param white: use whitened representation |
| 414 | :return: mean, variance |
| 415 | """ |
| 416 | R = tf.shape(f)[1] |
| 417 | M, N, P = tf.unstack(tf.shape(Kmn), num=Kmn.shape.ndims, axis=0) |
| 418 | |
| 419 | Lm = tf.linalg.cholesky(Kmm) |
| 420 | |
| 421 | # Compute the projection matrix A |
| 422 | # Lm: [M, M] Kmn: [M, N, P], flattened to [M, N * P] for the triangular solve |
| 423 | Kmn = tf.reshape(Kmn, (M, N * P)) # [M, P] |
| 424 | A = tf.linalg.triangular_solve(Lm, Kmn, lower=True) # [M, P] |
| 425 | Ar = tf.reshape(A, (M, N, P)) |
| 426 | |
| 427 | # compute the covariance due to the conditioning |
| 428 | if full_cov and full_output_cov: |
| 429 | # fvar = Knn - tf.linalg.matmul(Ar, Ar, transpose_a=True) # [P, P], then reshape? |
| 430 | fvar = Knn - tf.tensordot(Ar, Ar, [[0], [0]]) # [N, P, N, P] |
| 431 | elif full_cov and not full_output_cov: |
| 432 | At = tf.transpose(Ar) # [P, N, M] |
| 433 | fvar = Knn - tf.linalg.matmul(At, At, transpose_b=True) # [P, N, N] |
| 434 | elif not full_cov and full_output_cov: |
| 435 | # This transpose is annoying |
| 436 | At = tf.transpose(Ar, [1, 0, 2]) # [N, M, P] |
| 437 | # fvar = Knn - tf.einsum('mnk,mnl->nkl', Ar, Ar) |
| 438 | fvar = Knn - tf.linalg.matmul(At, At, transpose_a=True) # [N, P, P] |
| 439 | elif not full_cov and not full_output_cov: |
| 440 | # Knn: [N, P] |
| 441 | # Can also do this with a matmul |
| 442 | fvar = Knn - tf.reshape(tf.reduce_sum(tf.square(A), [0]), (N, P)) |
| 443 | |
| 444 | # another backsubstitution in the unwhitened case |
| 445 | if not white: |
| 446 | A = tf.linalg.triangular_solve(Lm, A, adjoint=True) # [M, P] |
| 447 | |
| 448 | # f: [M, R] |
| 449 | fmean = tf.linalg.matmul(f, A, transpose_a=True) # [R, M] * [M, P] -> [R, P] |
| 450 | fmean = tf.reshape(fmean, (R, N, P)) # [R, N, P] |
| 451 | |
| 452 | if q_sqrt is not None: |
searching dependent graphs…