
Function cglb_conjugate_gradient

gpflow/models/cglb.py:348–422 (github.com/GPflow/GPflow)

Conjugate gradient algorithm used in the CGLB model. The method of conjugate gradient (Hestenes and Stiefel, 1952) produces a sequence of vectors :math:`v_0, v_1, v_2, ..., v_N` such that :math:`v_0` = initial, and (in exact arithmetic) :math:`Kv_N = b`. In practice, the :math:`v_i` often converge quickly to approximate :math:`K^{-1}b`, and the algorithm can be stopped without running N iterations.
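The iteration described above can be sketched in plain NumPy. This sketch is hypothetical and is not GPflow code: the name `pcg_sketch`, the `precondition` callable (assumed to return `r @ Q^{-1}`), and the use of a separate step length for each row are assumptions. It mirrors the conventions visible in the source excerpt on this page: rows of `b` and `initial` are right-hand sides, products are taken as `v @ K`, iteration continues while ½ rᵀQ⁻¹r exceeds `cg_tolerance` and fewer than `max_steps` steps have run, and every `restart_cg_step` iterations the residual is recomputed as `b - v @ K`. Resetting the search direction to the preconditioned residual at a restart is also an assumption, because the excerpt ends before that update.

```python
import numpy as np


def pcg_sketch(K, b, initial, precondition, cg_tolerance, max_steps, restart_cg_step):
    """Hypothetical NumPy sketch of preconditioned CG with periodic restarts.

    Rows of ``b`` and ``initial`` (shape [P, N]) are independent right-hand
    sides. K ([N, N]) must be symmetric positive definite, so
    v @ K == (K @ v.T).T. ``precondition(r)`` must return r @ Q^{-1} for a
    symmetric positive definite preconditioner Q.
    """
    v = np.array(initial, dtype=float)
    r = b - v @ K                           # residuals, one row per system
    z = precondition(r)
    p = z
    rz = np.sum(r * z, axis=-1)             # r Q^{-1} r^T for each row
    i = 0
    # Continue while any row has 0.5 * rz above tolerance and steps remain.
    while np.any(0.5 * rz > cg_tolerance) and i < max_steps:
        Ap = p @ K
        gamma = rz / np.sum(p * Ap, axis=-1)  # per-row step length (assumption)
        v = v + gamma[:, None] * p
        restart = i % restart_cg_step == restart_cg_step - 1
        if restart:
            r = b - v @ K                   # recompute residual from scratch
        else:
            r = r - gamma[:, None] * Ap     # cheap recursive update
        z = precondition(r)
        new_rz = np.sum(r * z, axis=-1)
        if restart:
            p = z                           # assumption: reset search direction
        else:
            p = z + (new_rz / rz)[:, None] * p
        rz = new_rz
        i += 1
    return v
```

A Jacobi preconditioner such as `lambda r: r / np.diag(K)` is enough to exercise the sketch. A diagonal Q is symmetric positive definite but need not satisfy Q ≺ K; the CG recurrence itself converges for any symmetric positive definite Q, so the Q ≺ K assumption stated in the docstring is a requirement of this function rather than of the recurrence.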

(
    K: TensorType,
    b: TensorType,
    initial: TensorType,
    preconditioner: NystromPreconditioner,
    cg_tolerance: float,
    max_steps: int,
    restart_cg_step: int,
)


    "return: [P, N]",
)
def cglb_conjugate_gradient(
    K: TensorType,
    b: TensorType,
    initial: TensorType,
    preconditioner: NystromPreconditioner,
    cg_tolerance: float,
    max_steps: int,
    restart_cg_step: int,
) -> tf.Tensor:
    """
    Conjugate gradient algorithm used in the CGLB model. The method of
    conjugate gradient (Hestenes and Stiefel, 1952) produces a
    sequence of vectors :math:`v_0, v_1, v_2, ..., v_N` such that
    :math:`v_0` = initial, and (in exact arithmetic)
    :math:`Kv_N = b`. In practice, the :math:`v_i` often converge
    quickly to approximate :math:`K^{-1}b`, and the algorithm can be
    stopped without running N iterations.

    We assume the preconditioner, :math:`Q`, satisfies :math:`Q ≺ K`,
    and stop the algorithm when the residual :math:`r_i = b - Kv_i`
    satisfies :math:`½ ‖r_i‖²_{Q⁻¹} = ½ r_iᵀ Q⁻¹ r_i ≤ ϵ`, where
    :math:`ϵ` is ``cg_tolerance``, or when ``max_steps`` iterations
    have run.

    :param K: Matrix we want to backsolve from. Must be symmetric
        positive definite.
    :param b: Vector we want to backsolve.
    :param initial: Initial guess for the solution vector.
    :param preconditioner: Preconditioner function.
    :param cg_tolerance: Tolerance :math:`ϵ` in the stopping criterion
        above.
    :param max_steps: Maximum number of CG iterations.
    :param restart_cg_step: Number of iterations between restarts. At
        each restart, CG recomputes the residual directly from the
        current solution vector :math:`v` instead of updating it
        recursively, which can help avoid the build-up of numerical
        errors.

    :return: `v` where `v` approximately satisfies :math:`Kv = b`.
    """
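
    class CGState(NamedTuple):
        i: tf.Tensor
        v: tf.Tensor
        r: tf.Tensor
        p: tf.Tensor
        rz: tf.Tensor

    def stopping_criterion(state: CGState) -> tf.Tensor:
        # Continue while 0.5 * r Q^{-1} r^T exceeds the tolerance and the
        # step budget is not exhausted. tf.logical_and works on tensors in
        # both eager and graph mode, unlike Python's `and`.
        return tf.logical_and(0.5 * state.rz > cg_tolerance, state.i < max_steps)

    def cg_step(state: CGState) -> List[CGState]:
        Ap = state.p @ K  # row-vector convention: p K rather than K p
        denom = tf.reduce_sum(state.p * Ap, axis=-1)  # p K p^T
        gamma = state.rz / denom  # step length along the search direction
        v = state.v + gamma * state.p
        i = state.i + 1
        # Every restart_cg_step iterations, recompute the residual from
        # scratch to discard accumulated rounding error; otherwise use the
        # cheap recursive update.
        r = tf.cond(
            state.i % restart_cg_step == restart_cg_step - 1,
            lambda: b - v @ K,
            lambda: state.r - gamma * Ap,
        )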
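The docstring above assumes a preconditioner Q with Q ≺ K. One standard way to obtain such a Q for K = K_ff + σ²I is a Nyström approximation built from a subset of m columns, Q = K_fu K_uu⁻¹ K_uf + σ²I. Because K_ff − K_fu K_uu⁻¹ K_uf is a Schur complement of a PSD matrix, K − Q is PSD (Q ≼ K), and Q⁻¹ can be applied with the Woodbury identity so that only an m × m system has to be solved. The NumPy sketch below is hypothetical: the function name, the choice of columns via `inducing_idx`, and the row-vector `r @ Q^{-1}` convention (chosen to match the excerpt's `v @ K` style) are assumptions, not GPflow API.

```python
import numpy as np


def nystrom_preconditioner_sketch(Kff, sigma_sq, inducing_idx):
    """Hypothetical Nyström-style preconditioner for K = Kff + sigma_sq * I.

    Q = Kfu Kuu^{-1} Kuf + sigma_sq * I, where Kfu holds the columns of Kff
    listed in ``inducing_idx``. Returns a function that maps row vectors r
    ([P, N]) to r @ Q^{-1} using the Woodbury identity, so only an m x m
    system is solved (m = len(inducing_idx)).
    """
    Kfu = Kff[:, inducing_idx]                       # [N, m]
    Kuu = Kff[np.ix_(inducing_idx, inducing_idx)]    # [m, m]
    L = np.linalg.cholesky(Kuu)                      # Kuu = L L^T
    U = np.linalg.solve(L, Kfu.T).T                  # U U^T = Kfu Kuu^{-1} Kuf
    M = sigma_sq * np.eye(U.shape[1]) + U.T @ U      # Woodbury inner matrix

    def apply_inv(r):
        # Q^{-1} = (I - U M^{-1} U^T) / sigma_sq, applied from the right.
        correction = np.linalg.solve(M, (r @ U).T).T @ U.T
        return (r - correction) / sigma_sq

    return apply_inv
```

This is not GPflow's `NystromPreconditioner`, whose construction and call signature are not shown on this page; it only illustrates why a Nyström-type Q satisfies the ordering the docstring assumes and how Q⁻¹ can be applied cheaply.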

Callers (3)

quad_term (method)
predict_f (method)

Calls (2)

default_int (function)
CGState (class)

Tested by (1)
