MCPcopy
hub / github.com/tensorlayer/TensorLayer / __init__

Method __init__

examples/reinforcement_learning/tutorial_TD3.py:212–246  ·  view source on GitHub ↗
(
            self, state_dim, action_dim, action_range, hidden_dim, replay_buffer, policy_target_update_interval=1,
            q_lr=3e-4, policy_lr=3e-4
    )

Source from the content-addressed store, hash-verified

210class TD3:
211
def __init__(
        self, state_dim, action_dim, action_range, hidden_dim, replay_buffer, policy_target_update_interval=1,
        q_lr=3e-4, policy_lr=3e-4
):
    """Build the TD3 agent: twin Q-networks, a policy network, their target copies, and optimizers.

    Parameters
    ----------
    state_dim, action_dim, hidden_dim:
        Forwarded unchanged to every ``QNetwork`` and ``PolicyNetwork`` constructor.
    action_range:
        Forwarded to both ``PolicyNetwork`` constructors (online and target).
    replay_buffer:
        Stored as ``self.replay_buffer``; not otherwise used here.
    policy_target_update_interval:
        Stored as-is (default 1). Presumably the number of critic updates between
        policy/target updates in the training step -- confirm against the update method.
    q_lr:
        Learning rate for each of the two critic Adam optimizers.
    policy_lr:
        Learning rate for the policy Adam optimizer.
    """
    self.replay_buffer = replay_buffer

    # Constructor arguments shared by all critics and by both actors.
    critic_args = (state_dim, action_dim, hidden_dim)
    actor_args = critic_args + (action_range,)

    # NOTE: construction order is kept deliberately (online critics, target critics,
    # online policy, target policy) so any order-dependent layer naming is unchanged.
    self.q_net1 = QNetwork(*critic_args)
    self.q_net2 = QNetwork(*critic_args)
    self.target_q_net1 = QNetwork(*critic_args)
    self.target_q_net2 = QNetwork(*critic_args)
    self.policy_net = PolicyNetwork(*actor_args)
    self.target_policy_net = PolicyNetwork(*actor_args)
    print('Q Network (1,2): ', self.q_net1)
    print('Policy Network: ', self.policy_net)

    # Hard-copy each online network's weights into its target via target_ini,
    # keeping whatever network object target_ini hands back.
    for online_attr, target_attr in (
        ('q_net1', 'target_q_net1'),
        ('q_net2', 'target_q_net2'),
        ('policy_net', 'target_policy_net'),
    ):
        copied = self.target_ini(getattr(self, online_attr), getattr(self, target_attr))
        setattr(self, target_attr, copied)

    # Online networks are switched to train mode and targets to eval mode,
    # in the same call order as a straight-line sequence would use.
    for net, trainable in (
        (self.q_net1, True),
        (self.q_net2, True),
        (self.target_q_net1, False),
        (self.target_q_net2, False),
        (self.policy_net, True),
        (self.target_policy_net, False),
    ):
        if trainable:
            net.train()
        else:
            net.eval()

    # Update counter starts at zero; presumably advanced by the training step.
    self.update_cnt = 0
    self.policy_target_update_interval = policy_target_update_interval

    # One Adam optimizer per critic, plus one for the policy.
    self.q_optimizer1 = tf.optimizers.Adam(q_lr)
    self.q_optimizer2 = tf.optimizers.Adam(q_lr)
    self.policy_optimizer = tf.optimizers.Adam(policy_lr)
247
248 def target_ini(self, net, target_net):
249 """ hard-copy update for initializing target networks """

Callers 2

__init__ (Method) · 0.45
__init__ (Method) · 0.45

Calls 5

target_ini (Method) · 0.95
QNetwork (Class) · 0.85
eval (Method) · 0.80
PolicyNetwork (Class) · 0.70
train (Method) · 0.45

Tested by

no test coverage detected