Update parameters under a KL-divergence constraint :return: None
(self)
| 153 | self.critic_opt.apply_gradients(zip(grad, self.critic.trainable_weights)) |
| 154 | |
def update(self, lam_min=1e-4, lam_max=10.0):
    """
    Run one PPO update over the buffered trajectory data, then clear the buffers.

    The actor is trained for ``ACTOR_UPDATE_STEPS`` steps and the critic for
    ``CRITIC_UPDATE_STEPS`` steps.

    When ``self.method == 'kl_pen'``, the KL penalty coefficient ``self.lam``
    is adapted from the KL value returned by the last actor step:
    - it is halved if that KL is below ``kl_target / 1.5``;
    - it is doubled if that KL is above ``kl_target * 1.5``;
    - it is then clipped to ``[lam_min, lam_max]``.

    :param lam_min: lower bound for the adaptive KL coefficient. Without a
        positive floor, repeated halving can underflow ``self.lam`` to 0.0.
        After that, doubling can never restore the penalty.
    :param lam_max: upper bound for the adaptive KL coefficient. It prevents
        repeated doubling from inflating the penalty (and the loss) without limit.
    :return: None
    """
    s = np.array(self.state_buffer, np.float32)
    a = np.array(self.action_buffer, np.float32)
    r = np.array(self.cumulative_reward_buffer, np.float32)

    # Distribution of the current policy, built before any actor step below;
    # it is passed unchanged to every train_actor call as `pi`.
    mean, std = self.actor(s), tf.exp(self.actor.logstd)
    pi = tfp.distributions.Normal(mean, std)
    # Advantages use the critic as it is *before* the critic update below.
    adv = r - self.critic(s)

    # update actor
    if self.method == 'kl_pen':
        kl = None  # stays None when ACTOR_UPDATE_STEPS == 0 (no step was taken)
        for _ in range(ACTOR_UPDATE_STEPS):
            kl = self.train_actor(s, a, adv, pi)
        if kl is not None:
            if kl < self.kl_target / 1.5:
                self.lam /= 2
            elif kl > self.kl_target * 1.5:
                self.lam *= 2
            # Keep lam in a safe range so it can neither collapse to 0 nor explode.
            # NOTE(review): assumes self.lam is a plain Python scalar — confirm.
            self.lam = min(max(self.lam, lam_min), lam_max)
    else:
        for _ in range(ACTOR_UPDATE_STEPS):
            self.train_actor(s, a, adv, pi)

    # update critic
    for _ in range(CRITIC_UPDATE_STEPS):
        self.train_critic(r, s)

    self.state_buffer.clear()
    self.action_buffer.clear()
    self.cumulative_reward_buffer.clear()
    self.reward_buffer.clear()
| 188 | def get_action(self, state, greedy=False): |
| 189 | """ |
no test coverage detected