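The stopping rule in the docstring of `cglb_conjugate_gradient` works because the preconditioner sits below K. Here is a short derivation. It uses column vectors and assumes K is positive definite (for example a kernel matrix plus a noise term) and 0 ≺ Q ⪯ K:

```latex
\begin{aligned}
&v^\star = K^{-1} b, \qquad e_i = v_i - v^\star, \qquad r_i = b - K v_i = -K e_i, \\
&\|e_i\|_K^2 = e_i^\top K e_i = r_i^\top K^{-1} K K^{-1} r_i = r_i^\top K^{-1} r_i, \\
&0 \prec Q \preceq K \;\Longrightarrow\; K^{-1} \preceq Q^{-1}
  \;\Longrightarrow\; r_i^\top K^{-1} r_i \le r_i^\top Q^{-1} r_i, \\
&\text{hence}\quad \tfrac{1}{2}\, r_i^\top Q^{-1} r_i \le \epsilon
  \;\Longrightarrow\; \tfrac{1}{2}\, \|v_i - v^\star\|_K^2 \le \epsilon .
\end{aligned}
```

The excerpted `stopping_criterion` compares `0.5 * state.rz` with `cg_tolerance`. Assume `rz` holds rᵢᵀQ⁻¹rᵢ, as its name and the docstring suggest. Under that assumption, the loop stops only once the error in the K-norm is controlled by the tolerance, or once `max_steps` is reached.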
```python
    "return: [P, N]",
)
def cglb_conjugate_gradient(
    K: TensorType,
    b: TensorType,
    initial: TensorType,
    preconditioner: NystromPreconditioner,
    cg_tolerance: float,
    max_steps: int,
    restart_cg_step: int,
) -> tf.Tensor:
    """
    Conjugate gradient algorithm used in the CGLB model. The method of
    conjugate gradient (Hestenes and Stiefel, 1952) produces a
    sequence of vectors :math:`v_0, v_1, v_2, ..., v_N` such that
    :math:`v_0` = initial, and (in exact arithmetic)
    :math:`Kv_N = b`. In practice, the v_i often converge quickly to
    approximate :math:`K^{-1}b`, and the algorithm can be stopped
    without running N iterations.

    We assume the preconditioner, :math:`Q`, satisfies :math:`Q ≺ K`,
    and stop the algorithm when the residual :math:`r_i = b - Kv_i`
    satisfies :math:`0.5 rᵢᵀQ⁻¹rᵢ <= ϵ`, or once ``max_steps``
    iterations have run.

    :param K: Matrix we want to backsolve from. Must be PSD.
    :param b: Vector we want to backsolve.
    :param initial: Initial vector solution.
    :param preconditioner: Preconditioner :math:`Q`.
    :param cg_tolerance: Tolerance :math:`ϵ` used in the stopping
        criterion above.
    :param max_steps: Maximum number of CG iterations.
    :param restart_cg_step: Number of iterations after which CG
        restarts, reinitialising its internal state with the current
        solution vector :math:`v` as the new starting point. This can
        help avoid the build-up of numerical errors.

    :return: An approximate solution :math:`v` of :math:`Kv = b`.
    """

    class CGState(NamedTuple):
        i: tf.Tensor
        v: tf.Tensor
        r: tf.Tensor
        p: tf.Tensor
        rz: tf.Tensor

    def stopping_criterion(state: CGState) -> tf.Tensor:
        return (0.5 * state.rz > cg_tolerance) and (state.i < max_steps)

    def cg_step(state: CGState) -> List[CGState]:
        Ap = state.p @ K
        denom = tf.reduce_sum(state.p * Ap, axis=-1)
        gamma = state.rz / denom
        v = state.v + gamma * state.p
        i = state.i + 1
        r = tf.cond(
            state.i % restart_cg_step == restart_cg_step - 1,
            lambda: b - v @ K,
            lambda: state.r - gamma * Ap,
        )
```
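The following NumPy sketch can be used to experiment with the iteration outside TensorFlow. It runs preconditioned CG in the same row-vector layout as `cglb_conjugate_gradient`: `b` and `v` have shape [P, N], and products are written `p @ K`. It is an illustration under stated assumptions, not GPflow's implementation:

- `pcg`, `rbf` and `solve_Q` are hypothetical helpers.
- The preconditioner is a dense Nyström-style matrix Q = K_fz K_zz⁻¹ K_zf + σ²I, built by hand rather than through a `NystromPreconditioner`.
- On a restart step the sketch recomputes r = b − vK, as the excerpt does. It also resets the search direction to z = Q⁻¹r, which is our reading of "resets the internal state"; the excerpt stops before showing how `p` is updated.

```python
import numpy as np


def rbf(x1, x2, lengthscale=0.5):
    """Squared-exponential kernel matrix between 1-D inputs (hypothetical helper)."""
    d = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)


def pcg(K, b, initial, solve_Q, cg_tolerance, max_steps, restart_cg_step):
    """Preconditioned CG for v K = b, with rows of v and b laid out as [P, N].

    solve_Q(r) must return Q^{-1} applied to each row of r (an assumption of
    this sketch). Stops once 0.5 * r Q^{-1} r^T <= cg_tolerance for every row,
    or after max_steps iterations.
    """
    v = initial.copy()
    r = b - v @ K
    z = solve_Q(r)
    rz = np.sum(r * z, axis=-1)              # r Q^{-1} r^T per row, shape [P]
    p = z
    i = 0
    while np.any(0.5 * rz > cg_tolerance) and i < max_steps:
        Ap = p @ K                           # K is symmetric, so p @ K = (K p^T)^T
        gamma = rz / np.sum(p * Ap, axis=-1) # step length per row
        v = v + gamma[:, None] * p
        restart = i % restart_cg_step == restart_cg_step - 1
        if restart:
            r = b - v @ K                    # recompute the exact residual
        else:
            r = r - gamma[:, None] * Ap      # cheap recursive residual update
        z = solve_Q(r)
        new_rz = np.sum(r * z, axis=-1)
        # On a restart, drop the old search direction (assumption, see lead-in).
        p = z if restart else z + (new_rz / rz)[:, None] * p
        rz = new_rz
        i += 1
    return v, i


rng = np.random.default_rng(0)
N, M, noise = 200, 15, 0.1
x = rng.uniform(0.0, 5.0, N)
zs = np.linspace(0.0, 5.0, M)                # inducing inputs on a grid
K = rbf(x, x) + noise * np.eye(N)            # K = K_ff + sigma^2 I
Kfz = rbf(x, zs)
Kzz = rbf(zs, zs) + 1e-6 * np.eye(M)         # small jitter for a stable solve
Q = Kfz @ np.linalg.solve(Kzz, Kfz.T) + noise * np.eye(N)  # Nystrom, so Q <= K


def solve_Q(r):
    # Dense solve with Q for each row of r; fine for a small demo.
    return np.linalg.solve(Q, r.T).T


b = rng.normal(size=(1, N))
v, steps = pcg(K, b, np.zeros_like(b), solve_Q,
               cg_tolerance=1e-10, max_steps=500, restart_cg_step=10)
print(steps, np.linalg.norm(b - v @ K))
```

With `restart_cg_step=1`, every step restarts, so the loop reduces to preconditioned steepest descent. Large values recover ordinary preconditioned CG, at the cost of letting rounding errors in the recursive residual accumulate.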