Calculate the mean KL divergence between the current policy and the old policy. :param states: state batch :param old_mean: mean batch of the old pi :param old_log_std: log std batch of the old pi :return: kl_mean, the KL divergence averaged over the batch
(self, states, old_mean, old_log_std)
| 280 | self.critic_optimizer.apply_gradients(zip(grad, self.critic.trainable_weights)) |
| 281 | |
def kl(self, states, old_mean, old_log_std):
    """
    Mean KL divergence from the current policy to a stored old policy.

    Both policies are represented as ``tfp.distributions.Normal`` objects.
    The element-wise KL(current || old) is summed over axis 1 and the
    result is averaged over the remaining entries.

    :param states: state batch, passed to ``self.actor`` to get the current mean
    :param old_mean: mean batch of the old pi
    :param old_log_std: log std batch of the old pi
    :return: kl_mean, the averaged KL divergence
    """
    # Old policy: add an axis at position 1 to both parameters before
    # building the distribution (assumes the stored batches need this extra
    # axis to line up with the actor output -- TODO confirm against caller).
    prev_mean = old_mean[:, np.newaxis]
    prev_log_std = old_log_std[:, np.newaxis]
    prev_pi = tfp.distributions.Normal(prev_mean, tf.exp(prev_log_std))

    # Current policy: the actor supplies the mean; its log_std attribute is
    # exponentiated and expanded to the mean's shape via ones_like.
    cur_mean = self.actor(states)
    cur_std = tf.exp(self.actor.log_std) * tf.ones_like(cur_mean)
    cur_pi = tfp.distributions.Normal(cur_mean, cur_std)

    # Element-wise KL, reduced over axis 1, then averaged.
    elementwise_kl = tfp.distributions.kl_divergence(cur_pi, prev_pi)
    summed_kl = tf.reduce_sum(elementwise_kl, axis=1)
    return tf.reduce_mean(summed_kl)
| 302 | |
| 303 | def _flat_concat(self, xs): |
| 304 | """ |
no outgoing calls
no test coverage detected