Get an action for the given state. :param state: state input :param greedy: if True, use the policy mean as the action instead of sampling from the policy distribution :return: action, v, logp_pi, mean, log_std
(self, state, greedy=False)
| 211 | self.action_bound = action_bound |
| 212 | |
| 213 | def get_action(self, state, greedy=False): |
| 214 | """ |
| 215 | get action |
| 216 | :param state: state input |
| 217 | :param greedy: get action greedy or not |
| 218 | :return: pi, v, logp_pi, mean, log_std |
| 219 | """ |
| 220 | state = np.array([state], np.float32) |
| 221 | mean = self.actor(state) |
| 222 | log_std = tf.convert_to_tensor(self.actor.log_std) |
| 223 | std = tf.exp(log_std) |
| 224 | std = tf.ones_like(mean) * std |
| 225 | pi = tfp.distributions.Normal(mean, std) |
| 226 | |
| 227 | if greedy: |
| 228 | action = mean |
| 229 | else: |
| 230 | action = pi.sample() |
| 231 | action = np.clip(action, -self.action_bound, self.action_bound) |
| 232 | logp_pi = pi.log_prob(action) |
| 233 | |
| 234 | value = self.critic(state) |
| 235 | return action[0], value, logp_pi, mean, log_std |
| 236 | |
| 237 | def pi_loss(self, states, actions, adv, old_log_prob): |
| 238 | """ |