The inducing outputs live in the g-space (R^L). Interdomain conditional calculation. :param full_cov: calculate covariance between inputs :param full_output_cov: calculate covariance between outputs :param white: use whitened representation :return: mean, variance
(
Kmn: tf.Tensor,
Kmm: tf.Tensor,
Knn: tf.Tensor,
f: tf.Tensor,
*,
full_cov: bool = False,
full_output_cov: bool = False,
q_sqrt: Optional[tf.Tensor] = None,
white: bool = False,
)
| 255 | "return[1]: [N, P, N, P] if full_cov and full_output_cov", |
| 256 | ) |
| 257 | def independent_interdomain_conditional( |
| 258 | Kmn: tf.Tensor, |
| 259 | Kmm: tf.Tensor, |
| 260 | Knn: tf.Tensor, |
| 261 | f: tf.Tensor, |
| 262 | *, |
| 263 | full_cov: bool = False, |
| 264 | full_output_cov: bool = False, |
| 265 | q_sqrt: Optional[tf.Tensor] = None, |
| 266 | white: bool = False, |
| 267 | ) -> MeanAndVariance: |
| 268 | """ |
| 269 | The inducing outputs live in the g-space (R^L). |
| 270 | |
| 271 | Interdomain conditional calculation. |
| 272 | |
| 273 | :param full_cov: calculate covariance between inputs |
| 274 | :param full_output_cov: calculate covariance between outputs |
| 275 | :param white: use whitened representation |
| 276 | :return: mean, variance |
| 277 | """ |
| 278 | M, L, N, P = tf.unstack(tf.shape(Kmn), num=Kmn.shape.ndims, axis=0) |
| 279 | |
| 280 | Lm = tf.linalg.cholesky(Kmm) # [L, M, M] |
| 281 | |
| 282 | # Compute the projection matrix A |
| 283 | Kmn = tf.reshape(tf.transpose(Kmn, (1, 0, 2, 3)), (L, M, N * P)) |
| 284 | A = tf.linalg.triangular_solve(Lm, Kmn, lower=True) # [L, M, M] \ [L, M, N*P] -> [L, M, N*P] |
| 285 | Ar = tf.reshape(A, (L, M, N, P)) |
| 286 | |
| 287 | # compute the covariance due to the conditioning |
| 288 | if full_cov and full_output_cov: |
| 289 | fvar = Knn - tf.tensordot(Ar, Ar, [[0, 1], [0, 1]]) # [N, P, N, P] |
| 290 | elif full_cov and not full_output_cov: |
| 291 | At = tf.reshape(tf.transpose(Ar), (P, N, M * L)) # [P, N, L] |
| 292 | fvar = Knn - tf.linalg.matmul(At, At, transpose_b=True) # [P, N, N] |
| 293 | elif not full_cov and full_output_cov: |
| 294 | At = tf.reshape(tf.transpose(Ar, [2, 3, 1, 0]), (N, P, M * L)) # [N, P, L] |
| 295 | fvar = Knn - tf.linalg.matmul(At, At, transpose_b=True) # [N, P, P] |
| 296 | elif not full_cov and not full_output_cov: |
| 297 | fvar = Knn - tf.reshape(tf.reduce_sum(tf.square(A), [0, 1]), (N, P)) # Knn: [N, P] |
| 298 | |
| 299 | # another backsubstitution in the unwhitened case |
| 300 | if not white: |
| 301 | A = tf.linalg.triangular_solve( |
| 302 | Lm, A, adjoint=True |
| 303 | ) # [L, M, M] \ [L, M, N*P] -> [L, M, N*P] |
| 304 | Ar = tf.reshape(A, (L, M, N, P)) |
| 305 | |
| 306 | fmean = tf.tensordot(Ar, f, [[1, 0], [0, 1]]) # [N, P] |
| 307 | |
| 308 | if q_sqrt is not None: |
| 309 | if q_sqrt.shape.ndims == 3: |
| 310 | Lf = tf.linalg.band_part(q_sqrt, -1, 0) # [L, M, M] |
| 311 | LTA = tf.linalg.matmul( |
| 312 | Lf, A, transpose_a=True |
| 313 | ) # [L, M, M] * [L, M, P] -> [L, M, P] |
| 314 | else: # q_sqrt [M, L] |
searching dependent graphs…