(state: CGState)
| 393 | return (0.5 * state.rz > cg_tolerance) and (state.i < max_steps) |
| 394 | |
def cg_step(state: CGState) -> List[CGState]:
    """Advance the conjugate-gradient iteration by one step.

    Reads ``K``, ``b``, ``restart_cg_step`` and ``preconditioner`` from the
    enclosing scope.

    :param state: current iterate; the fields used here are ``i`` (step
        counter), ``v`` (solution estimate), ``r`` (residual), ``p``
        (search direction) and ``rz`` (the scalar paired with ``r``; for
        states produced by this function it is the second value returned
        by ``preconditioner(r)``).
    :return: a single-element list holding the updated ``CGState``.
    """
    direction = state.p
    K_direction = direction @ K
    curvature = tf.reduce_sum(direction * K_direction, axis=-1)
    step_size = state.rz / curvature

    next_v = state.v + step_size * direction
    next_i = state.i + 1

    # Both conditionals branch on the same predicate, which is evaluated on
    # the incoming counter (state.i), not on the incremented one.
    is_restart_step = state.i % restart_cg_step == restart_cg_step - 1

    def residual_from_scratch():
        # Restart step: rebuild the residual directly from b and the new v.
        return b - next_v @ K

    def residual_from_recurrence():
        # Regular step: update the previous residual incrementally.
        return state.r - step_size * K_direction

    next_r = tf.cond(is_restart_step, residual_from_scratch, residual_from_recurrence)

    z, next_rz = preconditioner(next_r)

    def direction_reset():
        # Restart step: drop the previous search direction.
        return z

    def direction_conjugated():
        # Regular step: mix in the previous direction, scaled by new_rz / old_rz.
        return z + direction * next_rz / state.rz

    next_p = tf.cond(is_restart_step, direction_reset, direction_conjugated)

    return [CGState(next_i, next_v, next_r, next_p, next_rz)]
| 413 | |
| 414 | Kv = initial @ K |
| 415 | r = b - Kv |
nothing calls this directly
no test coverage detected
searching dependent graphs…