Compute the actor (pi) gradient of the policy loss. :param states: state batch :param actions: action batch :param adv: advantage batch :param old_log_prob: old log-probability batch :return: tuple (gradient, loss)
(self, states, actions, adv, old_log_prob)
| 251 | return -surr |
| 252 | |
| 253 | def gradient(self, states, actions, adv, old_log_prob): |
| 254 | """ |
| 255 | Compute the gradient of self.pi_loss with respect to the actor (pi) weights. |
| 256 | :param states: state batch |
| 257 | :param actions: action batch |
| 258 | :param adv: advantage batch |
| 259 | :param old_log_prob: old log-probability batch |
| 260 | :return: tuple (gradient, loss) — the gradient w.r.t. self.actor.trainable_weights after self._flat_concat, and the loss value |
| 261 | """ |
| 262 | pi_params = self.actor.trainable_weights |
| 263 | with tf.GradientTape() as tape:  # record ops on the forward pass so the loss can be differentiated |
| 264 | loss = self.pi_loss(states, actions, adv, old_log_prob) |
| 265 | grad = tape.gradient(loss, pi_params)  # one gradient entry per trainable weight |
| 266 | gradient = self._flat_concat(grad)  # presumably concatenates per-weight grads into a single flat vector — confirm |
| 267 | return gradient, loss |
| 268 | |
| 269 | def train_vf(self, states, rewards_to_go): |
| 270 | """ |