MCPcopy Index your code
hub / github.com/tensorlayer/TensorLayer / TRPO

Class TRPO

examples/reinforcement_learning/tutorial_TRPO.py:180–437  ·  view source on GitHub ↗

TRPO (Trust Region Policy Optimization) agent class

Source from the content-addressed store, hash-verified

178
179
180class TRPO:
181 """
182 trpo class
183 """
184
def __init__(self, state_dim, action_dim, action_bound):
    """
    build the critic and actor networks, the GAE buffer and the critic optimizer
    :param state_dim: width of the state input of both networks
    :param action_dim: number of units of the actor output and of log_std
    :param action_bound: factor the tanh output of the actor is multiplied by
    """
    self.action_bound = action_bound

    # critic: state -> relu Dense stack (one per HIDDEN_SIZES entry) -> 1 unit
    with tf.name_scope('critic'):
        critic_in = tl.layers.Input([None, state_dim], tf.float32)
        hidden = critic_in
        for width in HIDDEN_SIZES:
            hidden = tl.layers.Dense(width, tf.nn.relu)(hidden)
        value_out = tl.layers.Dense(1)(hidden)
    critic = tl.models.Model(critic_in, value_out)
    critic.train()
    self.critic = critic

    # actor: state -> relu Dense stack -> tanh Dense -> scaled by action_bound
    with tf.name_scope('actor'):
        actor_in = tl.layers.Input([None, state_dim], tf.float32)
        hidden = actor_in
        for width in HIDDEN_SIZES:
            hidden = tl.layers.Dense(width, tf.nn.relu)(hidden)
        squashed = tl.layers.Dense(action_dim, tf.nn.tanh)(hidden)
        scaled_mean = tl.layers.Lambda(lambda x: x * action_bound)(squashed)
        # log_std is a free variable (initialised to zeros), not a network output
        log_std = tf.Variable(np.zeros(action_dim, dtype=np.float32))

    actor = tl.models.Model(actor_in, scaled_mean)
    # register log_std next to the network weights and keep a handle on the model
    actor.trainable_weights.append(log_std)
    actor.log_std = log_std
    actor.train()
    self.actor = actor

    self.buf = GAE_Buffer(state_dim, action_dim, BATCH_SIZE, GAMMA, LAM)
    self.critic_optimizer = tf.optimizers.Adam(learning_rate=VF_LR)
212
def get_action(self, state, greedy=False):
    """
    choose an action for one state and evaluate it
    :param state: state input
    :param greedy: if True take the actor mean, otherwise sample from the policy
    :return: action, value, logp_pi, mean, log_std
    """
    # wrap the single state in a list so the networks see a leading axis of 1
    batch = np.array([state], np.float32)

    mean = self.actor(batch)
    log_std = tf.convert_to_tensor(self.actor.log_std)
    # expand std to the shape of mean
    std = tf.ones_like(mean) * tf.exp(log_std)
    policy = tfp.distributions.Normal(mean, std)

    raw_action = mean if greedy else policy.sample()
    bound = self.action_bound
    action = np.clip(raw_action, -bound, bound)
    # log-probability is taken of the clipped action
    logp_pi = policy.log_prob(action)

    value = self.critic(batch)
    return action[0], value, logp_pi, mean, log_std
236
237 def pi_loss(self, states, actions, adv, old_log_prob):

Callers 1

tutorial_TRPO.py · File · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…