| 219 | |
| 220 | |
| 221 | class SAC: |
| 222 | |
def __init__(
        self, state_dim, action_dim, action_range, hidden_dim, replay_buffer, SOFT_Q_LR=3e-4, POLICY_LR=3e-4,
        ALPHA_LR=3e-4
):
    """Construct the SAC agent's networks, entropy temperature and optimizers.

    Creates two soft-Q networks plus a target copy of each, a policy network,
    a learnable ``log_alpha`` temperature variable, and one Adam optimizer per
    learner (soft-Q 1, soft-Q 2, policy, alpha).

    Args:
        state_dim: Forwarded to ``SoftQNetwork`` and ``PolicyNetwork``.
        action_dim: Forwarded to ``SoftQNetwork`` and ``PolicyNetwork``.
        action_range: Forwarded to ``PolicyNetwork`` (presumably the bound used
            to scale actions -- confirm in ``PolicyNetwork``).
        hidden_dim: Forwarded to every network (presumably the hidden-layer
            width -- confirm in the network classes).
        replay_buffer: Stored on ``self``; ``update()`` draws batches from it
            via ``replay_buffer.sample(batch_size)``.
        SOFT_Q_LR: Adam learning rate used for each soft-Q network optimizer.
        POLICY_LR: Adam learning rate for the policy optimizer.
        ALPHA_LR: Adam learning rate for the ``log_alpha`` optimizer.
    """
    self.replay_buffer = replay_buffer

    # initialize all networks
    self.soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim)
    self.soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim)
    self.target_soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim)
    self.target_soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim)
    self.policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim, action_range)

    # Set each model's mode once. The learners are put in train mode; the
    # target copies are put in eval mode -- no optimizer is created for them
    # below, their weights are set by copying (target_ini / target_soft_update).
    self.soft_q_net1.train()
    self.soft_q_net2.train()
    self.target_soft_q_net1.eval()
    self.target_soft_q_net2.eval()
    self.policy_net.train()

    # The entropy temperature is learned in log space; log_alpha starts at 0,
    # so alpha starts at exp(0) = 1.
    self.log_alpha = tf.Variable(0, dtype=np.float32, name='log_alpha')
    # NOTE(review): this is a one-off exp(log_alpha) value; it does not follow
    # later changes to log_alpha by itself -- presumably update() recomputes
    # it after each alpha step; verify.
    self.alpha = tf.math.exp(self.log_alpha)
    # soft_q_net2 is built with the same arguments as soft_q_net1, so only
    # the first one is printed.
    print('Soft Q Network (1,2): ', self.soft_q_net1)
    print('Policy Network: ', self.policy_net)

    # initialize weights of target networks as exact copies of the online nets
    self.target_soft_q_net1 = self.target_ini(self.soft_q_net1, self.target_soft_q_net1)
    self.target_soft_q_net2 = self.target_ini(self.soft_q_net2, self.target_soft_q_net2)

    self.soft_q_optimizer1 = tf.optimizers.Adam(SOFT_Q_LR)
    self.soft_q_optimizer2 = tf.optimizers.Adam(SOFT_Q_LR)
    self.policy_optimizer = tf.optimizers.Adam(POLICY_LR)
    self.alpha_optimizer = tf.optimizers.Adam(ALPHA_LR)
| 260 | |
def target_ini(self, net, target_net):
    """Hard-copy ``net``'s trainable weights into ``target_net``.

    Weights are paired positionally. If the two weight lists differ in length,
    only the common prefix is copied (plain ``zip`` semantics).

    Args:
        net: Source network whose ``trainable_weights`` are read.
        target_net: Destination network whose weights are ``assign``-ed in place.

    Returns:
        ``target_net`` itself, so the caller can rebind it.
    """
    weight_pairs = zip(target_net.trainable_weights, net.trainable_weights)
    for dst_weight, src_weight in weight_pairs:
        dst_weight.assign(src_weight)
    return target_net
| 266 | |
def target_soft_update(self, net, target_net, soft_tau):
    """Blend ``net``'s weights into ``target_net`` by Polyak averaging.

    Every paired weight becomes ``target * (1 - soft_tau) + online * soft_tau``.
    Pairing is positional (``zip`` semantics).

    Args:
        net: Online network supplying the new weights.
        target_net: Network whose weights are ``assign``-ed in place.
        soft_tau: Interpolation factor. 0.0 leaves the target unchanged;
            1.0 makes it equal to ``net`` (for finite weights).

    Returns:
        ``target_net`` itself, so the caller can rebind it.
    """
    weight_pairs = zip(target_net.trainable_weights, net.trainable_weights)
    for dst_weight, src_weight in weight_pairs:
        # keep (1 - soft_tau) of the old target, mix in soft_tau of the online weight
        blended = dst_weight * (1.0 - soft_tau) + src_weight * soft_tau
        dst_weight.assign(blended)
    return target_net
| 274 | |
| 275 | def update(self, batch_size, reward_scale=10., auto_entropy=True, target_entropy=-2, gamma=0.99, soft_tau=1e-2): |
| 276 | """ update all networks in SAC """ |
| 277 | state, action, reward, next_state, done = self.replay_buffer.sample(batch_size) |
| 278 | |