Compute the mean and variance of the latent function at some new points. Note that this is very similar to the SGPR prediction, for which there are notes in the SGPR notebook. Note: This model does not allow full output covariances. :param Xnew: points at which to predict
(
self, Xnew: InputData, full_cov: bool = False, full_output_cov: bool = False
)
| 217 | |
@inherit_check_shapes
def predict_f(
    self, Xnew: InputData, full_cov: bool = False, full_output_cov: bool = False
) -> MeanAndVariance:
    """
    Predict the latent function's mean and variance at new input locations.

    The computation mirrors the SGPR prediction (see the SGPR notebook for the
    derivation). The difference is that the training inputs are uncertain:
    the kernel statistics Psi1 and Psi2 are expectations under the diagonal
    Gaussian built from ``self.X_data_mean`` and ``self.X_data_var``, rather
    than kernel evaluations at fixed inputs.

    Note: This model does not allow full output covariances.

    :param Xnew: points at which to predict
    :param full_cov: if True, return a full covariance over the prediction
        points for every output column; otherwise return marginal variances.
    :param full_output_cov: must be False; a True value is rejected by
        ``assert_params_false``.
    :return: ``(mean, var)``. The variance does not depend on the observed
        outputs, so it is tiled over the output columns of ``self.data``:
        rank 3 with the output count leading when ``full_cov`` is True, rank 2
        with the output count trailing otherwise.
    """
    assert_params_false(self.predict_f, full_output_cov=full_output_cov)

    kernel = self.kernel
    inducing = self.inducing_variable
    kern_and_feat = (kernel, inducing)

    # Expected kernel statistics under the variational distribution of the
    # training inputs; Psi2 is summed over the data points.
    q_x = DiagonalGaussian(self.X_data_mean, self.X_data_var)
    psi1 = expectation(q_x, kern_and_feat)
    psi2 = tf.reduce_sum(expectation(q_x, kern_and_feat, kern_and_feat), axis=0)

    kuu_chol = tf.linalg.cholesky(covariances.Kuu(inducing, kernel, jitter=default_jitter()))
    noise_var = self.likelihood.variance

    # Whiten by the Kuu Cholesky factor: proj = L^{-1} Psi1^T and
    # scaled_psi2 = L^{-1} Psi2 L^{-T} / sigma^2 (two triangular solves).
    proj = tf.linalg.triangular_solve(kuu_chol, tf.transpose(psi1), lower=True)
    half_solved = tf.linalg.triangular_solve(kuu_chol, psi2, lower=True)
    scaled_psi2 = (
        tf.linalg.triangular_solve(kuu_chol, tf.transpose(half_solved), lower=True) / noise_var
    )
    identity = tf.eye(inducing.num_inducing, dtype=default_float())
    b_chol = tf.linalg.cholesky(scaled_psi2 + identity)

    y_obs = self.data
    weights = (
        tf.linalg.triangular_solve(b_chol, tf.linalg.matmul(proj, y_obs), lower=True) / noise_var
    )

    # Project the cross-covariance to the new points through both factors.
    k_u_star = covariances.Kuf(inducing, kernel, Xnew)
    solved_kuu = tf.linalg.triangular_solve(kuu_chol, k_u_star, lower=True)
    solved_b = tf.linalg.triangular_solve(b_chol, solved_kuu, lower=True)
    mean = tf.linalg.matmul(solved_b, weights, transpose_a=True)

    # The variance is identical for every output column; replicate it so its
    # column count matches the mean's.
    num_outputs = tf.shape(y_obs)[1]
    if not full_cov:
        marginal = (
            kernel(Xnew, full_cov=False)
            + tf.reduce_sum(tf.square(solved_b), axis=0)
            - tf.reduce_sum(tf.square(solved_kuu), axis=0)
        )
        var = tf.tile(tf.expand_dims(marginal, 1), tf.stack([1, num_outputs]))
    else:
        joint = (
            kernel(Xnew)
            + tf.linalg.matmul(solved_b, solved_b, transpose_a=True)
            - tf.linalg.matmul(solved_kuu, solved_kuu, transpose_a=True)
        )
        var = tf.tile(tf.expand_dims(joint, 0), tf.stack([num_outputs, 1, 1]))

    return mean + self.mean_function(Xnew), var
| 275 | |
| 276 | @inherit_check_shapes |