```python
        "return.AAT: [M, M]",
    )
    def _common_calculation(self) -> "SGPR.CommonTensors":
        """
        Compute the matrices used in the log-determinant calculation.

        :return:
            * :math:`σ²`, the per-point noise variance,
            * :math:`σ`,
            * :math:`A = L⁻¹K_{uf}/σ`, where :math:`LLᵀ = Kᵤᵤ`,
            * :math:`B = AAᵀ + I`,
            * :math:`L_B`, where :math:`L_B L_Bᵀ = B`,
            * :math:`AAT = AAᵀ`,
            * :math:`L`, the Cholesky factor of :math:`Kᵤᵤ`.
        """
        x, _ = self.data  # x: [N, D]
        iv = self.inducing_variable  # [M]

        sigma_sq = tf.squeeze(self.likelihood.variance_at(x), axis=-1)  # [N]
        sigma = tf.sqrt(sigma_sq)  # [N]

        kuf = Kuf(iv, self.kernel, x)  # [M, N]
        kuu = Kuu(iv, self.kernel, jitter=default_jitter())  # [M, M]
        L = tf.linalg.cholesky(kuu)  # [M, M]

        # Compute intermediate matrices
        A = tf.linalg.triangular_solve(L, kuf / sigma, lower=True)  # [M, N]
        AAT = tf.linalg.matmul(A, A, transpose_b=True)  # [M, M]
        B = add_noise_cov(AAT, tf.cast(1.0, AAT.dtype))  # [M, M], AAᵀ + I
        LB = tf.linalg.cholesky(B)  # [M, M]

        return self.CommonTensors(sigma_sq, sigma, A, B, LB, AAT, L)

    @check_shapes(
        "return: []",
```
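In the collapsed SGPR bound, the log-determinant is that of Q_ff + diag(σ²), with Q_ff = K_fu Kᵤᵤ⁻¹ K_uf. Writing A = L⁻¹K_uf/σ and applying Sylvester's determinant identity turns it into log|B| + Σₙ log σₙ², so only the M×M factor L_B is needed. The NumPy sketch below rebuilds the same tensors on synthetic 1-D data and checks that identity numerically. It is illustrative only: `rbf` and `common_calculation` are hypothetical helpers, the kernel, jitter, and data are assumptions, and `np.linalg.solve` stands in for `tf.linalg.triangular_solve`.

```python
import numpy as np


def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    # Hypothetical squared-exponential kernel on 1-D inputs: [len(x1), len(x2)].
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)


def common_calculation(x, z, sigma_sq, jitter=1e-6):
    # Hypothetical NumPy analogue of _common_calculation (not GPflow's API).
    m = z.shape[0]
    sigma = np.sqrt(sigma_sq)                 # [N]
    kuf = rbf(z, x)                           # [M, N]
    kuu = rbf(z, z) + jitter * np.eye(m)      # [M, M]
    L = np.linalg.cholesky(kuu)               # [M, M], L Lᵀ = Kuu
    # A general solve stands in for a triangular solve; the result is the same.
    A = np.linalg.solve(L, kuf / sigma)       # [M, N], column n divided by σ_n
    AAT = A @ A.T                             # [M, M]
    B = AAT + np.eye(m)                       # [M, M], B = AAᵀ + I
    LB = np.linalg.cholesky(B)                # [M, M], LB LBᵀ = B
    return sigma_sq, sigma, A, B, LB, AAT, L


rng = np.random.default_rng(0)
x = rng.uniform(0.0, 5.0, size=50)            # N = 50 training inputs (synthetic)
z = np.linspace(0.0, 5.0, 6)                  # M = 6 inducing inputs (assumed)
noise_var = rng.uniform(0.05, 0.2, size=50)   # per-point noise variance σ²

sigma_sq, sigma, A, B, LB, AAT, L = common_calculation(x, z, noise_var)

# Cheap route: uses only the M x M factor LB plus the diagonal noise term.
logdet_cheap = 2.0 * np.sum(np.log(np.diag(LB))) + np.sum(np.log(sigma_sq))

# Direct route: log|Qff + diag(σ²)| on the full N x N matrix.
kuu = rbf(z, z) + 1e-6 * np.eye(6)
kuf = rbf(z, x)
qff = kuf.T @ np.linalg.solve(kuu, kuf)
sign, logdet_direct = np.linalg.slogdet(qff + np.diag(sigma_sq))
```

Once AAᵀ is formed (O(NM²)), factoring B costs O(M³), whereas the direct route factors an N×N matrix at O(N³); this is why the method caches B and L_B rather than the full covariance.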