Conjugate gradient algorithm (see https://en.wikipedia.org/wiki/Conjugate_gradient_method).
Signature: `cg(self, Ax, b)`
| 349 | tl.files.load_hdf5_to_weights_in_order(os.path.join(path, 'critic.hdf5'), self.critic) |
| 350 | |
def cg(self, Ax, b, cg_iters=None, eps=None, residual_tol=0.0):
    """
    Conjugate gradient algorithm
    (see https://en.wikipedia.org/wiki/Conjugate_gradient_method)

    Approximately solve ``A x = b`` for ``x`` where ``A`` is only available
    through the matrix-vector product callable ``Ax``. CG assumes ``A`` is
    symmetric positive definite.

    :param Ax: callable mapping a vector ``v`` to ``A @ v``.
    :param b: right-hand side vector (1-D array-like). Non-floating input
        (ints, bools, lists of ints) is promoted to float64; floating input
        keeps its dtype.
    :param cg_iters: maximum number of CG iterations. ``None`` (default) uses
        the module-level ``CG_ITERS``.
    :param eps: small constant added to the ``p . A p`` denominator to guard
        against division by zero. ``None`` (default) uses the module-level
        ``EPS``.
    :param residual_tol: stop early once the squared residual norm ``r . r``
        is at or below this value. With the default 0.0 this only triggers
        when the residual is exactly zero, which is precisely the case where
        the un-guarded update would compute ``0 / 0`` and return NaN.
    :return: approximate solution ``x`` (same shape as ``b``).
    """
    if cg_iters is None:
        cg_iters = CG_ITERS
    if eps is None:
        eps = EPS

    b = np.asarray(b)
    # The in-place float updates below (x += ..., r -= ...) fail on integer
    # arrays, so promote non-floating inputs; float32/float64 are preserved.
    if not np.issubdtype(b.dtype, np.floating):
        b = b.astype(np.float64)

    x = np.zeros_like(b)
    # Residual r = b - A x; with the x = 0 start A x = 0, so r = b.
    # Change this if a warm start (non-zero initial x) is ever added.
    r = b.copy()
    p = r.copy()
    r_dot_old = np.dot(r, r)
    for _ in range(cg_iters):
        if r_dot_old <= residual_tol:
            # Converged (or b == 0). Continuing would evaluate
            # r_dot_new / r_dot_old as 0 / 0 and poison x with NaN.
            break
        z = Ax(p)
        alpha = r_dot_old / (np.dot(p, z) + eps)
        x += alpha * p
        r -= alpha * z
        r_dot_new = np.dot(r, r)
        p = r + (r_dot_new / r_dot_old) * p
        r_dot_old = r_dot_new
    return x
| 369 | |
| 370 | def hvp(self, states, old_mean, old_log_std, x): |
| 371 | """ |