Compute the mean and variance of the latent function at new input points `Xnew`. For a derivation of these terms, see the associated SGPR notebook.
(
self, Xnew: InputData, full_cov: bool = False, full_output_cov: bool = False
)
| 290 | |
@inherit_check_shapes
def predict_f(
    self, Xnew: InputData, full_cov: bool = False, full_output_cov: bool = False
) -> MeanAndVariance:
    """
    Predict the latent function's mean and variance at the inputs ``Xnew``.

    Uses the sparse (SGPR) predictive equations; the associated SGPR notebook
    contains the derivation. Writing ``L = chol(Kuu)``,
    ``A = L^-1 Kuf / sigma`` and ``LB = chol(A A^T + I)``, the test
    cross-covariance ``Kus`` is projected as ``P1 = L^-1 Kus`` and
    ``P2 = LB^-1 P1``; then

    - mean = P2^T LB^-1 A (err / sigma) + mean_function(Xnew)
    - var  = Kss + P2^T P2 - P1^T P1   (only the diagonal unless ``full_cov``)

    where ``err`` is the training targets minus the mean function at the
    training inputs, and ``sigma`` is the likelihood noise standard deviation
    evaluated at each training input.

    :param Xnew: input points at which to predict.
    :param full_cov: if True, return the full covariance between the points in
        ``Xnew``, tiled once per latent GP ([P, N, N]); otherwise return the
        marginal variances, tiled once per latent GP.
    :param full_output_cov: must be False (checked by ``assert_params_false``).
    :return: the pair ``(mean, var)``.
    """
    # Correlations between outputs are not supported by this implementation.
    assert_params_false(self.predict_f, full_output_cov=full_output_cov)

    X_train, Y_train = self.data
    m = self.inducing_variable.num_inducing
    residual = Y_train - self.mean_function(X_train)
    K_uf = Kuf(self.inducing_variable, self.kernel, X_train)
    K_uu = Kuu(self.inducing_variable, self.kernel, jitter=default_jitter())
    K_us = Kuf(self.inducing_variable, self.kernel, Xnew)

    # Noise standard deviation, one value per training input.
    noise_var = tf.squeeze(self.likelihood.variance_at(X_train), axis=-1)
    noise_std = tf.sqrt(noise_var)

    # NOTE(review): L_uu and L_B do not depend on Xnew, so they could be
    # cached between prediction calls (a fused posterior object could do this).
    L_uu = tf.linalg.cholesky(K_uu)
    A = tf.linalg.triangular_solve(L_uu, K_uf / noise_std, lower=True)  # L^-1 Kuf / sigma
    inner = tf.linalg.matmul(A, A, transpose_b=True) + tf.eye(m, dtype=default_float())
    L_B = tf.linalg.cholesky(inner)  # chol(A A^T + I)
    A_err = tf.linalg.matmul(A, residual / noise_std[..., None])
    c = tf.linalg.triangular_solve(L_B, A_err, lower=True)

    # Push the test cross-covariance through both Cholesky factors.
    proj_u = tf.linalg.triangular_solve(L_uu, K_us, lower=True)  # L^-1 Kus
    proj_b = tf.linalg.triangular_solve(L_B, proj_u, lower=True)  # LB^-1 L^-1 Kus
    f_mean = tf.linalg.matmul(proj_b, c, transpose_a=True)

    if not full_cov:
        # Marginal variances: the diagonals of the full-covariance terms.
        f_var = (
            self.kernel(Xnew, full_cov=False)
            + tf.reduce_sum(tf.square(proj_b), 0)
            - tf.reduce_sum(tf.square(proj_u), 0)
        )
        # Same marginal variance for every latent GP.
        f_var = tf.tile(f_var[:, None], [1, self.num_latent_gps])
    else:
        f_var = (
            self.kernel(Xnew)
            + tf.linalg.matmul(proj_b, proj_b, transpose_a=True)
            - tf.linalg.matmul(proj_u, proj_u, transpose_a=True)
        )
        f_var = tf.tile(f_var[None, ...], [self.num_latent_gps, 1, 1])  # [P, N, N]

    return f_mean + self.mean_function(Xnew), f_var
| 341 | |
| 342 | @check_shapes( |
| 343 | "return[0]: [M, P]", |
nothing calls this directly
no test coverage detected