MCPcopy
hub / github.com/tensorlayer/TensorLayer / __init__

Method __init__

examples/reinforcement_learning/tutorial_SAC.py:223–259  ·  view source on GitHub ↗
(
            self, state_dim, action_dim, action_range, hidden_dim, replay_buffer, SOFT_Q_LR=3e-4, POLICY_LR=3e-4,
            ALPHA_LR=3e-4
    )

Source from the content-addressed store, hash-verified

221class SAC:
222
def __init__(
        self, state_dim, action_dim, action_range, hidden_dim, replay_buffer, SOFT_Q_LR=3e-4, POLICY_LR=3e-4,
        ALPHA_LR=3e-4
):
    """Build the SAC agent's networks, entropy temperature and optimizers.

    Creates two soft-Q networks plus a target copy of each, a policy
    network, a learnable log-temperature ``log_alpha``, and one Adam
    optimizer per trainable component.

    Args:
        state_dim: state dimensionality, forwarded to every SoftQNetwork and
            to PolicyNetwork.
        action_dim: action dimensionality, forwarded likewise.
        action_range: forwarded to PolicyNetwork only (presumably the bound
            used to scale actions — confirm against PolicyNetwork).
        hidden_dim: hidden-layer width forwarded to every network.
        replay_buffer: stored as ``self.replay_buffer`` for use by other
            methods of the agent.
        SOFT_Q_LR: Adam learning rate for both soft-Q optimizers.
        POLICY_LR: Adam learning rate for the policy optimizer.
        ALPHA_LR: Adam learning rate for the temperature optimizer.
    """
    self.replay_buffer = replay_buffer

    # Two online soft-Q networks, a target copy of each, and the policy.
    self.soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim)
    self.soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim)
    self.target_soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim)
    self.target_soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim)
    self.policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim, action_range)

    # Set modes once: online networks train, and the target networks stay in
    # eval mode. The targets receive weights via target_ini below rather than
    # an optimizer of their own.
    self.soft_q_net1.train()
    self.soft_q_net2.train()
    self.target_soft_q_net1.eval()
    self.target_soft_q_net2.eval()
    self.policy_net.train()

    # Learnable entropy temperature, stored in log space. Starting at 0 gives
    # alpha = exp(0) = 1.
    self.log_alpha = tf.Variable(0, dtype=np.float32, name='log_alpha')
    # NOTE(review): tf.math.exp returns a value, not a live expression, so
    # self.alpha does not follow later changes to log_alpha unless it is
    # recomputed elsewhere (e.g. in the update step — confirm).
    self.alpha = tf.math.exp(self.log_alpha)
    print('Soft Q Network (1,2): ', self.soft_q_net1)
    print('Policy Network: ', self.policy_net)

    # Hard-copy the online Q weights into the targets so both start identical.
    self.target_soft_q_net1 = self.target_ini(self.soft_q_net1, self.target_soft_q_net1)
    self.target_soft_q_net2 = self.target_ini(self.soft_q_net2, self.target_soft_q_net2)

    # One Adam optimizer per trainable component.
    self.soft_q_optimizer1 = tf.optimizers.Adam(SOFT_Q_LR)
    self.soft_q_optimizer2 = tf.optimizers.Adam(SOFT_Q_LR)
    self.policy_optimizer = tf.optimizers.Adam(POLICY_LR)
    self.alpha_optimizer = tf.optimizers.Adam(ALPHA_LR)
260
261 def target_ini(self, net, target_net):
262 """ hard-copy update for initializing target networks """

Callers 2

__init__Method · 0.45
__init__Method · 0.45

Calls 5

target_iniMethod · 0.95
SoftQNetworkClass · 0.85
evalMethod · 0.80
PolicyNetworkClass · 0.70
trainMethod · 0.45

Tested by

no test coverage detected