| 208 | |
| 209 | |
| 210 | class TD3: |
| 211 | |
def __init__(
    self, state_dim, action_dim, action_range, hidden_dim, replay_buffer, policy_target_update_interval=1,
    q_lr=3e-4, policy_lr=3e-4
):
    """Assemble the TD3 agent: twin critics, an actor, their targets and optimizers.

    Args:
        state_dim: Forwarded to every ``QNetwork`` and ``PolicyNetwork``.
        action_dim: Forwarded to every ``QNetwork`` and ``PolicyNetwork``.
        action_range: Forwarded only to the two ``PolicyNetwork`` instances
            (presumably bounds the actor's output -- TODO confirm against
            ``PolicyNetwork``).
        hidden_dim: Hidden size forwarded to every network.
        replay_buffer: Stored on the agent; ``update`` draws training batches
            from it via ``sample(batch_size)``.
        policy_target_update_interval: Stored unchanged; presumably the number
            of critic updates between delayed actor/target updates -- verify
            against ``update``.
        q_lr: Learning rate for each critic's Adam optimizer.
        policy_lr: Learning rate for the actor's Adam optimizer.
    """
    self.replay_buffer = replay_buffer

    critic_args = (state_dim, action_dim, hidden_dim)
    actor_args = (state_dim, action_dim, hidden_dim, action_range)

    # Construction order (critic 1, critic 2, their targets, actor, actor
    # target) is kept fixed so any random weight initialisation draws in the
    # same sequence.
    self.q_net1, self.q_net2 = QNetwork(*critic_args), QNetwork(*critic_args)
    self.target_q_net1, self.target_q_net2 = QNetwork(*critic_args), QNetwork(*critic_args)
    self.policy_net = PolicyNetwork(*actor_args)
    self.target_policy_net = PolicyNetwork(*actor_args)
    print('Q Network (1,2): ', self.q_net1)
    print('Policy Network: ', self.policy_net)

    # Start each target network as a hard copy of its online counterpart.
    self.target_q_net1, self.target_q_net2, self.target_policy_net = (
        self.target_ini(online, target)
        for online, target in (
            (self.q_net1, self.target_q_net1),
            (self.q_net2, self.target_q_net2),
            (self.policy_net, self.target_policy_net),
        )
    )

    # Online networks are put in training mode, target networks in eval mode.
    for online in (self.q_net1, self.q_net2, self.policy_net):
        online.train()
    for target in (self.target_q_net1, self.target_q_net2, self.target_policy_net):
        target.eval()

    self.update_cnt = 0
    self.policy_target_update_interval = policy_target_update_interval

    # One Adam optimizer per critic, plus one for the actor.
    self.q_optimizer1, self.q_optimizer2 = tf.optimizers.Adam(q_lr), tf.optimizers.Adam(q_lr)
    self.policy_optimizer = tf.optimizers.Adam(policy_lr)
| 247 | |
def target_ini(self, net, target_net):
    """Hard-copy every trainable weight of ``net`` into ``target_net``.

    Intended for initializing target networks so they start out identical to
    their online counterparts.

    Args:
        net: Source network; its ``trainable_weights`` are read.
        target_net: Destination network; each of its ``trainable_weights``
            is overwritten in place via ``assign``.

    Returns:
        ``target_net`` -- the same object, mutated in place.

    Raises:
        ValueError: If the two networks expose a different number of
            trainable weights.
    """
    source_weights = net.trainable_weights
    target_weights = target_net.trainable_weights
    # zip() would silently stop at the shorter list and leave part of the
    # target un-copied; a count mismatch means the architectures differ.
    if len(source_weights) != len(target_weights):
        raise ValueError(
            "target_ini: trainable weight count mismatch "
            "({} in source vs {} in target)".format(len(source_weights), len(target_weights))
        )
    for target_param, param in zip(target_weights, source_weights):
        target_param.assign(param)
    return target_net
| 253 | |
def target_soft_update(self, net, target_net, soft_tau):
    """Soft-update ``target_net`` toward ``net`` with Polyak averaging.

    Every target weight ``t`` paired with source weight ``s`` becomes
    ``t * (1 - soft_tau) + s * soft_tau``. Small ``soft_tau`` values therefore
    move the target only slightly toward the online network on each call.

    Args:
        net: Source (online) network; its ``trainable_weights`` are read.
        target_net: Network whose ``trainable_weights`` are blended in place
            via ``assign``.
        soft_tau: Interpolation factor; 1.0 copies ``net`` outright and 0.0
            leaves ``target_net`` unchanged.

    Returns:
        ``target_net`` -- the same object, mutated in place.

    Raises:
        ValueError: If the two networks expose a different number of
            trainable weights.
    """
    source_weights = net.trainable_weights
    target_weights = target_net.trainable_weights
    # zip() would silently stop at the shorter list and leave part of the
    # target never updated; a count mismatch means the architectures differ.
    if len(source_weights) != len(target_weights):
        raise ValueError(
            "target_soft_update: trainable weight count mismatch "
            "({} in source vs {} in target)".format(len(source_weights), len(target_weights))
        )
    for target_param, param in zip(target_weights, source_weights):
        target_param.assign(  # copy the blended value into the target parameter
            target_param * (1.0 - soft_tau) + param * soft_tau
        )
    return target_net
| 261 | |
| 262 | def update(self, batch_size, eval_noise_scale, reward_scale=10., gamma=0.9, soft_tau=1e-2): |
| 263 | """ update all networks in TD3 """ |
| 264 | self.update_cnt += 1 |
| 265 | state, action, reward, next_state, done = self.replay_buffer.sample(batch_size) |
| 266 | |
| 267 | reward = reward[:, np.newaxis] # expand dim |