Repository: github.com/tensorlayer/TensorLayer

Method kl

File: examples/reinforcement_learning/tutorial_TRPO.py, lines 282–301

Calculate the KL divergence of the current policy pi from the old policy old_pi. Parameters: states, the state batch; old_mean, the mean batch of the old pi; old_log_std, the log-std batch of the old pi. Returns: the KL divergence averaged over the batch.

(self, states, old_mean, old_log_std)
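Written out, the value kl returns is the batch average of the per-state KL divergence between two diagonal Gaussians, using the standard closed form that tfp.distributions.kl_divergence evaluates element-wise for two Normal distributions. The notation is ours, not the tutorial's: N is the batch size, d the number of action dimensions, μ_ij the actor's mean for state i, σ_j = exp(actor.log_std_j) (state-independent), and μ^old, σ^old = exp(old_log_std) the stored old-policy statistics.

$$
\overline{D}_{\mathrm{KL}}\left(\pi \,\|\, \pi_{\mathrm{old}}\right)
= \frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{d}
\left[
\log\frac{\sigma^{\mathrm{old}}_{ij}}{\sigma_{j}}
+ \frac{\sigma_{j}^{2} + \left(\mu_{ij}-\mu^{\mathrm{old}}_{ij}\right)^{2}}{2\left(\sigma^{\mathrm{old}}_{ij}\right)^{2}}
- \frac{1}{2}
\right]
$$

The inner sum corresponds to tf.reduce_sum(kl, axis=1) and the outer average to tf.reduce_mean. The first argument of kl_divergence is pi, so this is KL(pi || old_pi).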


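```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp


# Method excerpt from the tutorial's agent class (shown dedented)
def kl(self, states, old_mean, old_log_std):
    """
    calculate kl-divergence
    :param states: state batch
    :param old_mean: mean batch of the old pi
    :param old_log_std: log std batch of the old pi
    :return: mean KL divergence over the batch (scalar tensor)
    """
    # Add a trailing action axis, e.g. shape (batch,) -> (batch, 1)
    old_mean = old_mean[:, np.newaxis]
    old_log_std = old_log_std[:, np.newaxis]
    old_std = tf.exp(old_log_std)
    old_pi = tfp.distributions.Normal(old_mean, old_std)

    # Current policy: state-dependent mean, state-independent log std
    mean = self.actor(states)
    std = tf.exp(self.actor.log_std) * tf.ones_like(mean)  # broadcast std to mean's shape
    pi = tfp.distributions.Normal(mean, std)

    # Element-wise KL(pi || old_pi), summed over action dims, averaged over the batch
    kl = tfp.distributions.kl_divergence(pi, old_pi)
    all_kls = tf.reduce_sum(kl, axis=1)
    return tf.reduce_mean(all_kls)
```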
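To check the shapes and the reduction order without TensorFlow, the computation in kl can be mirrored in NumPy. This is a sketch, not the tutorial's code: gaussian_kl is a hypothetical helper that takes the actor's output mean and its log_std directly instead of calling self.actor(states), uses the closed-form Gaussian KL in place of tfp.distributions.kl_divergence, and assumes a single action dimension with old_mean and old_log_std of shape (batch,), matching the [:, np.newaxis] reshape.

```python
import numpy as np


def gaussian_kl(mean, log_std, old_mean, old_log_std):
    """Batch-mean KL(pi || old_pi) for diagonal Gaussians (hypothetical helper).

    mean:        (batch, 1) current-policy means, i.e. what self.actor(states) would return
    log_std:     (1,) state-independent log std of the current policy
    old_mean:    (batch,) old-policy means
    old_log_std: (batch,) old-policy log stds
    """
    # Same reshape as kl(): (batch,) -> (batch, 1)
    old_mean = np.asarray(old_mean, dtype=np.float64)[:, np.newaxis]
    old_log_std = np.asarray(old_log_std, dtype=np.float64)[:, np.newaxis]
    mean = np.asarray(mean, dtype=np.float64)
    # Broadcast the state-independent log std to the shape of mean
    log_std = np.asarray(log_std, dtype=np.float64) * np.ones_like(mean)

    std, old_std = np.exp(log_std), np.exp(old_log_std)
    # Closed-form KL(N(mean, std) || N(old_mean, old_std)), element-wise
    kl = (old_log_std - log_std
          + (std ** 2 + (mean - old_mean) ** 2) / (2.0 * old_std ** 2)
          - 0.5)
    # Sum over the action axis, then average over the batch
    return float(np.mean(np.sum(kl, axis=1)))


if __name__ == "__main__":
    # Identical policies: the divergence is zero
    print(gaussian_kl(np.zeros((4, 1)), np.zeros(1), np.zeros(4), np.zeros(4)))
    # Means shifted by one unit standard deviation in every state
    print(gaussian_kl(np.ones((4, 1)), np.zeros(1), np.zeros(4), np.zeros(4)))
```

The second call returns 0.5: with unit standard deviations on both sides, a mean shift of 1 contributes (1 + 1) / 2 - 1/2 = 0.5 per state, and averaging over identical states leaves 0.5.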

Callers (2)

hvp · method · 0.95
set_and_eval · method · 0.95
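By its name, hvp presumably computes a Hessian-vector product of kl with respect to the actor's weights; near the old policy that Hessian is the Fisher information matrix used in TRPO's conjugate-gradient step. The tutorial's hvp is not reproduced here. The NumPy sketch below is a hypothetical stand-in under simplifying assumptions: one state, one action dimension, policy parameters θ = (μ, log σ) instead of network weights, and central differences of the analytic KL gradient in place of automatic differentiation. kl_grad and kl_hvp are illustrative names, not tutorial functions.

```python
import numpy as np


def kl_grad(theta, theta_old):
    """Gradient of KL(N(mu, e^s) || N(mu0, e^s0)) w.r.t. theta = (mu, s), with s = log sigma."""
    mu, s = theta
    mu0, s0 = theta_old
    var0 = np.exp(2.0 * s0)  # old variance sigma0^2
    # KL = s0 - s + (e^{2s} + (mu - mu0)^2) / (2 e^{2 s0}) - 1/2
    d_mu = (mu - mu0) / var0
    d_s = -1.0 + np.exp(2.0 * s) / var0
    return np.array([d_mu, d_s])


def kl_hvp(theta, theta_old, v, eps=1e-5):
    """Hessian-vector product H @ v of the KL, by central differences of kl_grad."""
    theta = np.asarray(theta, dtype=np.float64)
    v = np.asarray(v, dtype=np.float64)
    g_plus = kl_grad(theta + eps * v, theta_old)
    g_minus = kl_grad(theta - eps * v, theta_old)
    return (g_plus - g_minus) / (2.0 * eps)


if __name__ == "__main__":
    theta_old = np.array([0.3, np.log(0.5)])  # mu_old = 0.3, sigma_old = 0.5
    print(kl_grad(theta_old, theta_old))             # the gradient vanishes at theta_old
    print(kl_hvp(theta_old, theta_old, [1.0, 1.0]))  # approximately diag(1/sigma_old^2, 2) @ v
```

At θ = θ_old the Hessian of this KL is diag(1/σ_old², 2), so the product above should be close to [4, 2]. Because the gradient is zero there, the KL constraint has no first-order term, which is why TRPO approximates it with this second-order (Fisher) term.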

Calls

no outgoing calls

Tested by

no test coverage detected