hub / github.com/tensorlayer/TensorLayer / update

Method update

examples/reinforcement_learning/tutorial_PPO.py:155–186

Update the actor and critic parameters, subject to the KL-divergence constraint. Returns None.

(self)

Source from the content-addressed store, hash-verified

153         self.critic_opt.apply_gradients(zip(grad, self.critic.trainable_weights))
154
155     def update(self):
156         """
157         Update parameter with the constraint of KL divergent
158         :return: None
159         """
160         s = np.array(self.state_buffer, np.float32)
161         a = np.array(self.action_buffer, np.float32)
162         r = np.array(self.cumulative_reward_buffer, np.float32)
163         mean, std = self.actor(s), tf.exp(self.actor.logstd)
164         pi = tfp.distributions.Normal(mean, std)
165         adv = r - self.critic(s)
166
167         # update actor
168         if self.method == 'kl_pen':
169             for _ in range(ACTOR_UPDATE_STEPS):
170                 kl = self.train_actor(s, a, adv, pi)
171             if kl < self.kl_target / 1.5:
172                 self.lam /= 2
173             elif kl > self.kl_target * 1.5:
174                 self.lam *= 2
175         else:
176             for _ in range(ACTOR_UPDATE_STEPS):
177                 self.train_actor(s, a, adv, pi)
178
179         # update critic
180         for _ in range(CRITIC_UPDATE_STEPS):
181             self.train_critic(r, s)
182
183         self.state_buffer.clear()
184         self.action_buffer.clear()
185         self.cumulative_reward_buffer.clear()
186         self.reward_buffer.clear()
187
188     def get_action(self, state, greedy=False):
189         """
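
The 'kl_pen' branch in the listing adjusts the penalty coefficient self.lam according to how the measured KL divergence compares with self.kl_target. A minimal standalone sketch of that rule follows. The function name adapt_kl_coefficient and its tolerance and factor parameters are hypothetical and do not appear in the source; the 1.5 and 2 defaults mirror the constants in the listing, and the sample values in the demo loop are illustrative only.

```python
def adapt_kl_coefficient(lam, kl, kl_target, tolerance=1.5, factor=2.0):
    """Return an updated KL penalty coefficient (hypothetical helper).

    Mirrors the thresholds in the listing: shrink lam when the measured KL
    is well below the target, grow it when well above, otherwise keep it.
    """
    if kl < kl_target / tolerance:
        # Policy moved less than intended: relax the penalty.
        return lam / factor
    if kl > kl_target * tolerance:
        # Policy moved more than intended: tighten the penalty.
        return lam * factor
    # KL is within the tolerance band around the target: no change.
    return lam


if __name__ == "__main__":
    lam = 0.5  # illustrative starting coefficient, not taken from the source
    for kl in (0.001, 0.01, 0.05):  # illustrative KL measurements
        lam = adapt_kl_coefficient(lam, kl, kl_target=0.01)
        print(f"kl={kl:.3f} -> lam={lam}")
```

Keeping the rule as a pure function makes the threshold behaviour easy to check in isolation, without building the actor or critic networks.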

Callers 15

_fill_project_info · Method · 0.45
save_model · Method · 0.45
find_top_model · Method · 0.45
save_dataset · Method · 0.45
find_top_dataset · Method · 0.45
find_datasets · Method · 0.45
save_training_log · Method · 0.45
save_validation_log · Method · 0.45
save_testing_log · Method · 0.45
create_task · Method · 0.45
run_top_task · Method · 0.45
check_unfinished_task · Method · 0.45

Calls 2

train_actor · Method · 0.95
train_critic · Method · 0.95

Tested by

no test coverage detected