Calculate the Hessian-vector product. :param states: state batch :param old_mean: mean batch of the old pi :param old_log_std: log std batch of the old pi :param x: vector multiplied by the Hessian :return: hvp
(self, states, old_mean, old_log_std, x)
| 368 | return x |
| 369 | |
def hvp(self, states, old_mean, old_log_std, x):
    """
    Compute a Hessian-vector product by double backpropagation.

    The Hessian is that of ``self.kl(states, old_mean, old_log_std)`` with
    respect to the actor's trainable weights; it is never formed explicitly.
    Instead the flattened gradient is dotted with ``x`` and that scalar is
    differentiated a second time.

    :param states: state batch
    :param old_mean: mean batch of the old pi
    :param old_log_std: log std batch of the old pi
    :param x: vector to multiply by the Hessian (assumed to have the same
        layout as the flattened gradient -- confirm against the caller)
    :return: flattened Hessian-vector product, plus ``DAMPING_COEFF * x``
        when ``DAMPING_COEFF`` is positive
    """
    weights = self.actor.trainable_weights

    # The outer tape has to record the first-order gradient computation so
    # that it can be differentiated again afterwards.
    with tf.GradientTape() as outer_tape:
        with tf.GradientTape() as inner_tape:
            kl_value = self.kl(states, old_mean, old_log_std)
        # First-order gradient: taken once the inner tape has stopped
        # recording, but while the outer tape is still active.
        kl_grads = inner_tape.gradient(kl_value, weights)
        flat_grad = self._flat_concat(kl_grads)
        grad_dot_x = tf.reduce_sum(flat_grad * x)

    # Differentiating (grad . x) w.r.t. the weights gives H @ x.
    second_order = outer_tape.gradient(grad_dot_x, weights)
    product = self._flat_concat(second_order)

    # Optional damping term, applied only for a positive coefficient.
    if DAMPING_COEFF > 0:
        product += DAMPING_COEFF * x
    return product
| 389 | |
| 390 | def update(self): |
| 391 | """ |
no test coverage detected