TRPO class
| 178 | |
| 179 | |
| 180 | class TRPO: |
| 181 | """ |
| 182 | trpo class |
| 183 | """ |
| 184 | |
def __init__(self, state_dim, action_dim, action_bound):
    """
    Build the critic and Gaussian actor networks plus training helpers.

    :param state_dim: length of the state vector fed to both networks
    :param action_dim: number of action dimensions produced by the actor
    :param action_bound: factor applied to the tanh-squashed actor mean;
        also stored on the instance as ``self.action_bound``
    """

    def hidden_stack():
        # Input layer followed by one ReLU Dense layer per entry of HIDDEN_SIZES;
        # returns (input layer, last hidden layer).
        inp = tl.layers.Input([None, state_dim], tf.float32)
        out = inp
        for units in HIDDEN_SIZES:
            out = tl.layers.Dense(units, act=tf.nn.relu)(out)
        return inp, out

    # Value network: hidden stack -> single linear output unit.
    with tf.name_scope('critic'):
        critic_in, critic_hidden = hidden_stack()
        value_out = tl.layers.Dense(1)(critic_hidden)
    self.critic = tl.models.Model(inputs=critic_in, outputs=value_out)
    self.critic.train()

    # Policy network: hidden stack -> tanh mean scaled by action_bound, plus a
    # state-independent log standard deviation initialised to zeros.
    with tf.name_scope('actor'):
        actor_in, actor_hidden = hidden_stack()
        squashed = tl.layers.Dense(action_dim, act=tf.nn.tanh)(actor_hidden)
        scaled_mean = tl.layers.Lambda(lambda x: x * action_bound)(squashed)
        log_std = tf.Variable(np.zeros(action_dim, dtype=np.float32))

    self.actor = tl.models.Model(inputs=actor_in, outputs=scaled_mean)
    # log_std belongs to no layer, so it is appended to the actor's
    # trainable_weights by hand -- presumably so code that optimises
    # actor.trainable_weights also updates it (verify against pi_loss/update).
    self.actor.trainable_weights.append(log_std)
    self.actor.log_std = log_std
    self.actor.train()

    self.buf = GAE_Buffer(state_dim, action_dim, BATCH_SIZE, GAMMA, LAM)
    self.critic_optimizer = tf.optimizers.Adam(learning_rate=VF_LR)
    self.action_bound = action_bound
| 212 | |
def get_action(self, state, greedy=False):
    """
    Select an action for a single state and evaluate the critic on it.

    :param state: a single (un-batched) state; it is wrapped into a batch of
        one before being fed to the networks
    :param greedy: if True, use the policy mean as the action; otherwise
        sample from the Gaussian policy
    :return: (action, v, logp_pi, mean, log_std) where ``action`` is the
        action clipped to [-action_bound, action_bound] with the batch
        dimension removed (not the distribution), ``v`` is the critic output
        for the batched state, ``logp_pi`` is the element-wise
        log-probability of the clipped action under the policy, and
        ``mean`` / ``log_std`` are the policy parameters that were used
    """
    state = np.array([state], np.float32)
    mean = self.actor(state)
    log_std = tf.convert_to_tensor(self.actor.log_std)
    # Broadcast the state-independent std to the same shape as the mean.
    std = tf.ones_like(mean) * tf.exp(log_std)
    pi = tfp.distributions.Normal(mean, std)

    action = mean if greedy else pi.sample()
    action = np.clip(action, -self.action_bound, self.action_bound)
    # NOTE: the log-prob is evaluated on the *clipped* action, so for samples
    # outside the bound it is not the density of the raw sample. Presumably
    # this keeps it consistent with the action that is executed/stored --
    # confirm against the buffer and pi_loss before changing it.
    logp_pi = pi.log_prob(action)

    value = self.critic(state)
    return action[0], value, logp_pi, mean, log_std
| 236 | |
| 237 | def pi_loss(self, states, actions, adv, old_log_prob): |
no outgoing calls
no test coverage detected
searching dependent graphs…