MCPcopy Index your code
hub / github.com/tensorlayer/TensorLayer / SAC

Class SAC

examples/reinforcement_learning/tutorial_SAC.py:221–355  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

219
220
221class SAC:
222
223 def __init__(
224 self, state_dim, action_dim, action_range, hidden_dim, replay_buffer, SOFT_Q_LR=3e-4, POLICY_LR=3e-4,
225 ALPHA_LR=3e-4
226 ):
227 self.replay_buffer = replay_buffer
228
229 # initialize all networks
230 self.soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim)
231 self.soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim)
232 self.target_soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim)
233 self.target_soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim)
234 self.policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim, action_range)
235 self.soft_q_net1.train()
236 self.soft_q_net2.train()
237 self.target_soft_q_net1.eval()
238 self.target_soft_q_net2.eval()
239 self.policy_net.train()
240
241 self.log_alpha = tf.Variable(0, dtype=np.float32, name='log_alpha')
242 self.alpha = tf.math.exp(self.log_alpha)
243 print('Soft Q Network (1,2): ', self.soft_q_net1)
244 print('Policy Network: ', self.policy_net)
245 # set mode
246 self.soft_q_net1.train()
247 self.soft_q_net2.train()
248 self.target_soft_q_net1.eval()
249 self.target_soft_q_net2.eval()
250 self.policy_net.train()
251
252 # initialize weights of target networks
253 self.target_soft_q_net1 = self.target_ini(self.soft_q_net1, self.target_soft_q_net1)
254 self.target_soft_q_net2 = self.target_ini(self.soft_q_net2, self.target_soft_q_net2)
255
256 self.soft_q_optimizer1 = tf.optimizers.Adam(SOFT_Q_LR)
257 self.soft_q_optimizer2 = tf.optimizers.Adam(SOFT_Q_LR)
258 self.policy_optimizer = tf.optimizers.Adam(POLICY_LR)
259 self.alpha_optimizer = tf.optimizers.Adam(ALPHA_LR)
260
261 def target_ini(self, net, target_net):
262 """ hard-copy update for initializing target networks """
263 for target_param, param in zip(target_net.trainable_weights, net.trainable_weights):
264 target_param.assign(param)
265 return target_net
266
267 def target_soft_update(self, net, target_net, soft_tau):
268 """ soft update the target net with Polyak averaging """
269 for target_param, param in zip(target_net.trainable_weights, net.trainable_weights):
270 target_param.assign( # copy weight value into target parameters
271 target_param * (1.0 - soft_tau) + param * soft_tau
272 )
273 return target_net
274
275 def update(self, batch_size, reward_scale=10., auto_entropy=True, target_entropy=-2, gamma=0.99, soft_tau=1e-2):
276 """ update all networks in SAC """
277 state, action, reward, next_state, done = self.replay_buffer.sample(batch_size)
278

Callers 1

tutorial_SAC.pyFile · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected