calculate pi loss
:param states: state batch
:param actions: action batch
:param adv: advantage batch
:param old_log_prob: old log probability
:return: pi loss
(self, states, actions, adv, old_log_prob)
        return action[0], value, logp_pi, mean, log_std

    def pi_loss(self, states, actions, adv, old_log_prob):
        """
        Calculate the policy (pi) surrogate loss.
        :param states: state batch
        :param actions: action batch
        :param adv: advantage batch
        :param old_log_prob: log probability of the actions under the old policy
        :return: pi loss (negative surrogate objective)
        """
        # Mean of the Gaussian policy for each state in the batch.
        mean = self.actor(states)
        # The standard deviation comes from the actor's log_std attribute.
        pi = tfp.distributions.Normal(mean, tf.exp(self.actor.log_std))
        # Keeps only the first action dimension of the per-dimension log-probabilities.
        log_prob = pi.log_prob(actions)[:, 0]
        # Likelihood ratio between the current and the old policy.
        ratio = tf.exp(log_prob - old_log_prob)
        surr = tf.reduce_mean(ratio * adv)
        # Negated so that a lower loss corresponds to a higher surrogate objective.
        return -surr

    def gradient(self, states, actions, adv, old_log_prob):
        """
no outgoing calls
no test coverage detected
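
The arithmetic inside `pi_loss` can be checked without TensorFlow. The following is a minimal sketch in plain Python, assuming scalar actions and a single shared `log_std`; the helper names `normal_log_prob` and `surrogate_pi_loss` are hypothetical and are not part of the original class. It only illustrates the ratio-times-advantage computation and is not a replacement for the method above.

```python
import math


def normal_log_prob(x, mean, log_std):
    """Log-density of a scalar x under Normal(mean, exp(log_std)). Hypothetical helper."""
    std = math.exp(log_std)
    return -0.5 * ((x - mean) / std) ** 2 - log_std - 0.5 * math.log(2.0 * math.pi)


def surrogate_pi_loss(means, log_std, actions, adv, old_log_prob):
    """Negative batch mean of ratio * advantage, for scalar actions. Hypothetical helper."""
    total = 0.0
    for mu, a, advantage, old_lp in zip(means, actions, adv, old_log_prob):
        # Likelihood ratio between the current policy and the old one.
        ratio = math.exp(normal_log_prob(a, mu, log_std) - old_lp)
        total += ratio * advantage
    return -total / len(means)


# Made-up batch: the policy means stand in for the actor's output.
means = [0.0, 0.5, -0.2]
actions = [0.1, 0.4, 0.0]
adv = [1.0, -0.5, 2.0]
log_std = -0.5

# When the old policy equals the current one, every ratio is exactly 1,
# so the loss reduces to the negative mean advantage.
old_log_prob = [normal_log_prob(a, m, log_std) for a, m in zip(actions, means)]
loss = surrogate_pi_loss(means, log_std, actions, adv, old_log_prob)
```

With identical old and current policies the ratio term drops out, which makes this a convenient sanity check: any deviation from the negative mean advantage would point to a mismatch between `log_prob` and `old_log_prob`.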