Update policy network :param state: state batch :param action: action batch :param adv: advantage batch :param old_pi: old pi distribution :return: kl_mean or None
(self, state, action, adv, old_pi)
| 109 | self.action_bound = action_bound |
| 110 | |
def train_actor(self, state, action, adv, old_pi):
    """
    Update policy network

    Performs a single gradient step on the actor. The loss is the
    KL-penalised surrogate when ``self.method == 'penalty'`` and the
    clipped surrogate for any other value of ``self.method``.

    :param state: state batch
    :param action: action batch
    :param adv: advantage batch
    :param old_pi: old pi distribution
    :return: kl_mean when the penalty method is used, otherwise None
    """
    # Decide the variant once so the loss branch and the return value
    # can never disagree about which method is active.
    use_penalty = self.method == 'penalty'
    kl_mean = None

    with tf.GradientTape() as tape:
        mean, std = self.actor(state), tf.exp(self.actor.logstd)
        pi = tfp.distributions.Normal(mean, std)

        # Probability ratio new/old, computed in log space.
        ratio = tf.exp(pi.log_prob(action) - old_pi.log_prob(action))
        surr = ratio * adv
        if use_penalty:  # ppo penalty
            kl = tfp.distributions.kl_divergence(old_pi, pi)
            kl_mean = tf.reduce_mean(kl)
            loss = -(tf.reduce_mean(surr - self.lam * kl))
        else:  # ppo clip
            clipped_ratio = tf.clip_by_value(
                ratio, 1. - self.epsilon, 1. + self.epsilon
            )
            loss = -tf.reduce_mean(tf.minimum(surr, clipped_ratio * adv))

    a_grad = tape.gradient(loss, self.actor.trainable_weights)
    self.actor_opt.apply_gradients(zip(a_grad, self.actor.trainable_weights))

    # None for the clip variant; the mean KL for the penalty variant.
    return kl_mean
| 140 | |
| 141 | def train_critic(self, reward, state): |
| 142 | """ |