MCP · copy
hub / github.com/tensorlayer/TensorLayer / TD3

Class TD3

examples/reinforcement_learning/tutorial_TD3.py:210–336  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

208
209
210class TD3:
211
212 def __init__(
213 self, state_dim, action_dim, action_range, hidden_dim, replay_buffer, policy_target_update_interval=1,
214 q_lr=3e-4, policy_lr=3e-4
215 ):
216 self.replay_buffer = replay_buffer
217
218 # initialize all networks
219 self.q_net1 = QNetwork(state_dim, action_dim, hidden_dim)
220 self.q_net2 = QNetwork(state_dim, action_dim, hidden_dim)
221 self.target_q_net1 = QNetwork(state_dim, action_dim, hidden_dim)
222 self.target_q_net2 = QNetwork(state_dim, action_dim, hidden_dim)
223 self.policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim, action_range)
224 self.target_policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim, action_range)
225 print('Q Network (1,2): ', self.q_net1)
226 print('Policy Network: ', self.policy_net)
227
228 # initialize weights of target networks
229 self.target_q_net1 = self.target_ini(self.q_net1, self.target_q_net1)
230 self.target_q_net2 = self.target_ini(self.q_net2, self.target_q_net2)
231 self.target_policy_net = self.target_ini(self.policy_net, self.target_policy_net)
232
233 # set train mode
234 self.q_net1.train()
235 self.q_net2.train()
236 self.target_q_net1.eval()
237 self.target_q_net2.eval()
238 self.policy_net.train()
239 self.target_policy_net.eval()
240
241 self.update_cnt = 0
242 self.policy_target_update_interval = policy_target_update_interval
243
244 self.q_optimizer1 = tf.optimizers.Adam(q_lr)
245 self.q_optimizer2 = tf.optimizers.Adam(q_lr)
246 self.policy_optimizer = tf.optimizers.Adam(policy_lr)
247
248 def target_ini(self, net, target_net):
249 """ hard-copy update for initializing target networks """
250 for target_param, param in zip(target_net.trainable_weights, net.trainable_weights):
251 target_param.assign(param)
252 return target_net
253
254 def target_soft_update(self, net, target_net, soft_tau):
255 """ soft update the target net with Polyak averaging """
256 for target_param, param in zip(target_net.trainable_weights, net.trainable_weights):
257 target_param.assign( # copy weight value into target parameters
258 target_param * (1.0 - soft_tau) + param * soft_tau
259 )
260 return target_net
261
262 def update(self, batch_size, eval_noise_scale, reward_scale=10., gamma=0.9, soft_tau=1e-2):
263 """ update all networks in TD3 """
264 self.update_cnt += 1
265 state, action, reward, next_state, done = self.replay_buffer.sample(batch_size)
266
267 reward = reward[:, np.newaxis] # expand dim

Callers 1

tutorial_TD3.py · File · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected