    @inherit_check_shapes
    def predict_f(
        self,
        Xnew: InputData,
        full_cov: bool = False,
        full_output_cov: bool = False,
        cg_tolerance: Optional[float] = 1e-3,
    ) -> MeanAndVariance:
        """
        The posterior mean for the CGLB model is given by

        .. math::
            m(x_s) = K_{sf} v + Q_{sf} Q^{-1} r

        where :math:`K = K_{ff} + σ^2 I`, :math:`Q = Q_{ff} + σ^2 I` is its
        Nyström approximation, and :math:`r = y - K v` is the residual from CG.

        Note that when :math:`v = 0`, this agrees with the SGPR mean,
        while if :math:`v = K^{-1} y`, then :math:`r = 0`, and the exact
        GP mean is recovered.

        :param cg_tolerance: float or None: If None, the cached value of
            :math:`v` is used. If float, conjugate gradient is run until
            :math:`r^T Q^{-1} r < ϵ`, with :math:`ϵ` = ``cg_tolerance``.
        """
        assert_params_false(self.predict_f, full_output_cov=full_output_cov)

        x, y = self.data
        err = y - self.mean_function(x)
        kxx = self.kernel(x, x)
        ksf = self.kernel(Xnew, x)
        sigma_sq = self.likelihood.variance
        sigma = tf.sqrt(sigma_sq)
        iv = self.inducing_variable
        kernel = self.kernel
        matmul = tf.linalg.matmul
        trisolve = tf.linalg.triangular_solve

        kmat = add_noise_cov(kxx, sigma_sq)  # K = K_ff + σ²I

        common = self._common_calculation()
        A, LB, L = common.A, common.LB, common.L

        v = self.aux_vec
        if cg_tolerance is not None:
            # Refine v by preconditioned CG on K v = y - m(x), starting from the cached value.
            preconditioner = NystromPreconditioner(A, LB, sigma_sq)
            err_t = tf.transpose(err)
            v = cglb_conjugate_gradient(
                kmat,
                err_t,
                v,
                preconditioner,
                cg_tolerance,
                self._max_cg_iters,
                self._restart_cg_iters,
            )

        cg_mean = matmul(ksf, v, transpose_b=True)  # K_sf v
        res = err - matmul(kmat, v, transpose_b=True)  # r = y - K v

        Kus = Kuf(iv, kernel, Xnew)
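The two limiting cases in the `predict_f` docstring can be checked with a small dense NumPy sketch of the mean m(x_s) = K_sf v + Q_sf Q⁻¹ r, where the cross term is Q_sf = K_su K_uu⁻¹ K_uf so that the prediction lives at the test inputs. It builds every matrix explicitly, which the model is designed to avoid, and it assumes a zero mean function, a squared-exponential kernel, hand-picked inducing inputs and a small jitter on K_uu. The helper names `rbf` and `cglb_style_mean` are hypothetical and not part of GPflow.

```python
import numpy as np


# Hypothetical helper for illustration; not GPflow API.
def rbf(a, b, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel matrix between 1-D input arrays a and b."""
    d = (a[:, None] - b[None, :]) / lengthscale
    return variance * np.exp(-0.5 * d**2)


# Hypothetical helper for illustration; not GPflow API.
def cglb_style_mean(xs, x, y, z, v, noise_var, jitter=1e-8):
    """Dense evaluation of m(x_s) = K_sf v + Q_sf Q^{-1} r, with r = y - K v.

    K = K_ff + noise_var * I and Q = Q_ff + noise_var * I, where Q_ff and Q_sf
    are Nystrom approximations built from the inducing inputs z.
    """
    n = len(x)
    K = rbf(x, x) + noise_var * np.eye(n)
    Kuu = rbf(z, z) + jitter * np.eye(len(z))  # small jitter for a stable solve
    Kuf = rbf(z, x)
    Qff = Kuf.T @ np.linalg.solve(Kuu, Kuf)
    Qsf = rbf(xs, z) @ np.linalg.solve(Kuu, Kuf)
    Q = Qff + noise_var * np.eye(n)
    r = y - K @ v  # residual left over by the (possibly inexact) solve for v
    return rbf(xs, x) @ v + Qsf @ np.linalg.solve(Q, r)


rng = np.random.default_rng(0)
x = np.linspace(0.0, 5.0, 30)
y = np.sin(x) + 0.1 * rng.standard_normal(30)
z = np.linspace(0.0, 5.0, 6)   # inducing inputs (assumed)
xs = np.linspace(0.0, 5.0, 7)  # test inputs
noise = 0.01
K = rbf(x, x) + noise * np.eye(30)

# v = 0 gives r = y, so the mean is the SGPR mean Q_sf Q^{-1} y.
m_sgpr = cglb_style_mean(xs, x, y, z, np.zeros(30), noise)
# v = K^{-1} y gives r = 0, so the mean is the exact GP mean K_sf K^{-1} y.
m_exact = cglb_style_mean(xs, x, y, z, np.linalg.solve(K, y), noise)
```

For an approximate v in between, the second term corrects K_sf v using the residual r, and it shrinks as CG solves K v = y more accurately.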
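When `cg_tolerance` is set, the code refines `v` with conjugate gradient preconditioned by `NystromPreconditioner`, passing the cached `self.aux_vec` as the starting vector. The dense NumPy sketch below illustrates that idea with a textbook preconditioned CG that stops once rᵀQ⁻¹r < ϵ, as the docstring describes, using Q = Q_ff + σ²I as the preconditioner. It is not GPflow's `cglb_conjugate_gradient`: restarts (`_restart_cg_iters`), batching over outputs and the exact form of the stopping test there may differ. The names `preconditioned_cg` and `rbf`, the kernel and the data are assumptions for illustration.

```python
import numpy as np


# Hypothetical textbook PCG for illustration; not GPflow's cglb_conjugate_gradient.
def preconditioned_cg(K, b, v0, Q, tol=1e-3, max_iters=100):
    """Approximately solve K v = b with preconditioned conjugate gradient.

    Starts from v0 and stops once r^T Q^{-1} r < tol, where r = b - K v and
    Q is a symmetric positive-definite preconditioner (a dense matrix here).
    """
    v = v0.copy()
    r = b - K @ v
    z = np.linalg.solve(Q, r)  # preconditioned residual Q^{-1} r
    rz = r @ z
    p = z.copy()
    for _ in range(max_iters):
        if rz < tol:  # stopping rule from the docstring
            break
        Kp = K @ p
        alpha = rz / (p @ Kp)  # step length along the search direction
        v = v + alpha * p
        r = r - alpha * Kp
        z = np.linalg.solve(Q, r)
        rz_new = r @ z
        p = z + (rz_new / rz) * p  # next Q-conjugate search direction
        rz = rz_new
    return v


# Hypothetical helper for illustration; not GPflow API.
def rbf(a, b, lengthscale=1.0):
    """Squared-exponential kernel matrix between 1-D input arrays a and b."""
    d = (a[:, None] - b[None, :]) / lengthscale
    return np.exp(-0.5 * d**2)


rng = np.random.default_rng(1)
x = np.linspace(0.0, 5.0, 40)
y = np.sin(x) + 0.1 * rng.standard_normal(40)
z = np.linspace(0.0, 5.0, 8)  # inducing inputs for the Nystrom preconditioner (assumed)
noise = 0.01

K = rbf(x, x) + noise * np.eye(40)  # K = K_ff + sigma^2 I
Kuf = rbf(z, x)
Kuu = rbf(z, z) + 1e-8 * np.eye(8)
Q = Kuf.T @ np.linalg.solve(Kuu, Kuf) + noise * np.eye(40)  # Q = Q_ff + sigma^2 I

v = preconditioned_cg(K, y, np.zeros(40), Q, tol=1e-12)
```

Passing a previous solution as `v0` makes the loop exit immediately when its residual is already below `tol`, which is presumably why the model caches `aux_vec` between calls.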