

class NystromPreconditioner:
    """
    Preconditioner of the form :math:`Q = (Q_ff + σ²I)⁻¹`,
    where L is lower triangular with :math:`LLᵀ = Kᵤᵤ`,
    :math:`A = σ⁻¹L⁻¹Kᵤₓ` and :math:`B = AAᵀ + I = LᵦLᵦᵀ`.
    By the Woodbury identity, :math:`Q = σ⁻²(I - AᵀB⁻¹A)`.
    """

    @check_shapes(
        "A: [M, N]",
        "LB: [M, M]",
    )
    def __init__(self, A: tf.Tensor, LB: tf.Tensor, sigma_sq: float) -> None:
        self.A = A
        self.LB = LB
        self.sigma_sq = sigma_sq

    @check_shapes(
        "v: [B, N]",
        "return[0]: [B, N]",
        "return[1]: []",
    )
    def __call__(self, v: TensorType) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Computes :math:`vᵀQ` and :math:`vᵀQv`, i.e. applies
        :math:`(Q_ff + σ²I)⁻¹` to `v`. Note that this is implemented as
        multiplication of a row vector on the right.

        :param v: Row vectors we want to backsolve, one per row.
        """
        sigma_sq = self.sigma_sq
        A = self.A
        LB = self.LB

        trans = tf.transpose
        trisolve = tf.linalg.triangular_solve
        matmul = tf.linalg.matmul

        v = trans(v)  # [N, B]: work with column vectors internally
        Av = matmul(A, v)
        LBinvAv = trisolve(LB, Av)  # solves LB x = Av (lower=True by default)
        LBinvtLBinvAv = trisolve(trans(LB), LBinvAv, lower=False)  # B⁻¹Av

        rv = v - matmul(A, LBinvtLBinvAv, transpose_a=True)  # (I - AᵀB⁻¹A)v
        vtrv = tf.reduce_sum(rv * v)
        return trans(rv) / sigma_sq, vtrv / sigma_sq


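The arithmetic in `NystromPreconditioner.__call__` can be checked against a dense inverse on a small problem. The following is a minimal NumPy sketch, not the class's TensorFlow code: the helper name `nystrom_precondition`, the random stand-ins for `Kuu` and `Kuf`, and the problem sizes are all hypothetical, and it assumes the scaling `A = σ⁻¹L⁻¹Kᵤₓ` and `Q_ff = KₓᵤKᵤᵤ⁻¹Kᵤₓ`, under which the Woodbury identity gives `(Q_ff + σ²I)⁻¹ = σ⁻²(I - AᵀB⁻¹A)`.

```python
import numpy as np


def nystrom_precondition(A, LB, sigma_sq, v):
    """Apply (Q_ff + sigma_sq * I)^{-1} to each row of v via the Woodbury identity.

    A: [M, N], LB: [M, M] lower Cholesky factor of A A^T + I, v: [B, N].
    Hypothetical helper that mirrors the steps of __call__ in NumPy.
    """
    vt = v.T                                 # [N, B]
    Av = A @ vt                              # [M, B]
    # Generic solves for brevity; a triangular solver would exploit LB's structure.
    LBinvAv = np.linalg.solve(LB, Av)
    BinvAv = np.linalg.solve(LB.T, LBinvAv)  # B^{-1} A v
    rv = vt - A.T @ BinvAv                   # (I - A^T B^{-1} A) v
    return rv.T / sigma_sq, np.sum(rv * vt) / sigma_sq


rng = np.random.default_rng(0)
M, N, B = 3, 6, 2
sigma_sq = 0.5

# Hypothetical stand-ins for the kernel matrices (assumptions, not real kernels).
W = rng.standard_normal((M, M))
Kuu = W @ W.T + M * np.eye(M)                # symmetric positive definite
Kuf = rng.standard_normal((M, N))

L = np.linalg.cholesky(Kuu)                  # L L^T = Kuu
A = np.linalg.solve(L, Kuf) / np.sqrt(sigma_sq)
LB = np.linalg.cholesky(A @ A.T + np.eye(M))

v = rng.standard_normal((B, N))
vQ, vQv = nystrom_precondition(A, LB, sigma_sq, v)

# Dense reference: form Q_ff + sigma_sq * I explicitly and invert it.
Qff = Kuf.T @ np.linalg.solve(Kuu, Kuf)
dense = np.linalg.inv(Qff + sigma_sq * np.eye(N))
assert np.allclose(vQ, v @ dense)
assert np.isclose(vQv, np.sum((v @ dense) * v))
```

The structured version only solves `M x M` systems, which is the point of the construction when `M` is much smaller than `N`; the dense inverse here exists purely as a reference for the check.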
@check_shapes(