update trpo :return: None
(self)
| 388 | return hvp |
| 389 | |
def update(self):
    """Run one TRPO update: a KL-constrained policy step, then value fitting.

    Pulls one batch from ``self.buf`` and builds a search direction ``x``
    with ``self.cg``. It then scales ``x`` to the TRPO maximum step and runs
    a backtracking line search. A step is accepted only if it satisfies the
    hard constraint ``kl <= DELTA`` and does not increase the policy loss.
    If no step qualifies, the pre-update policy parameters are restored
    exactly. Finally, ``self.train_vf`` is called ``TRAIN_VF_ITERS`` times on
    the same batch's states and rewards-to-go.

    :return: None
    """
    states, actions, adv, rewards_to_go, logp_old_ph, old_mu, old_log_std = self.buf.get()
    # pi_l_old is the policy loss before the update; it is the reference the
    # line search compares against. g is presumably that loss's gradient.
    g, pi_l_old = self.gradient(states, actions, adv, logp_old_ph)

    def hessian_vec(v):
        # Hessian-vector product at the pre-update distribution
        # (old_mu, old_log_std); presumably the KL Hessian -- TODO confirm.
        return self.hvp(states, old_mu, old_log_std, v)

    # Search direction x, presumably x ~= H^{-1} g via conjugate gradient.
    x = self.cg(hessian_vec, g)

    # Step size chosen so that 0.5 * alpha**2 * x^T H x == DELTA.
    # EPS guards the division when x^T H x is ~0.
    alpha = np.sqrt(2 * DELTA / (np.dot(x, hessian_vec(x)) + EPS))
    old_params = self.get_pi_params()

    def set_and_eval(step):
        """Install ``old_params - alpha * x * step``; return ``[kl, loss]`` there."""
        params = old_params - alpha * x * step
        self.set_pi_params(params)
        d_kl = self.kl(states, old_mu, old_log_std)
        loss = self.pi_loss(states, actions, adv, logp_old_ph)
        return [d_kl, loss]

    # trpo with backtracking line search, hard kl:
    # try step fractions 1, c, c**2, ... with c = BACKTRACK_COEFF.
    for j in range(BACKTRACK_ITERS):
        kl, pi_l_new = set_and_eval(step=BACKTRACK_COEFF**j)
        if kl <= DELTA and pi_l_new <= pi_l_old:
            # Accepting new params at step of line search
            break
    else:
        # Line search failed: restore the saved params directly.
        # Recomputing them as ``old_params - alpha * x * 0.`` would produce
        # NaN whenever alpha or x is non-finite, for example when numerical
        # error makes x^T H x negative and np.sqrt returns nan. That is
        # exactly the case in which no step passes the checks above. It
        # would also cost an extra KL and loss evaluation.
        self.set_pi_params(old_params)

    # Value function updates
    for _ in range(TRAIN_VF_ITERS):
        self.train_vf(states, rewards_to_go)
| 425 | def finish_path(self, done, next_state): |
| 426 | """ |
no test coverage detected