(
self, state_dim, action_dim, action_range, hidden_dim, replay_buffer, SOFT_Q_LR=3e-4, POLICY_LR=3e-4,
ALPHA_LR=3e-4
)
| 221 | class SAC: |
| 222 | |
def __init__(
    self, state_dim, action_dim, action_range, hidden_dim, replay_buffer, SOFT_Q_LR=3e-4, POLICY_LR=3e-4,
    ALPHA_LR=3e-4
):
    """Build the networks, temperature variable and optimizers of the agent.

    Args:
        state_dim: forwarded to every ``SoftQNetwork`` and to ``PolicyNetwork``.
        action_dim: forwarded to every ``SoftQNetwork`` and to ``PolicyNetwork``.
        action_range: forwarded to ``PolicyNetwork`` only.
        hidden_dim: forwarded to every ``SoftQNetwork`` and to ``PolicyNetwork``.
        replay_buffer: stored as ``self.replay_buffer``; not used otherwise here.
        SOFT_Q_LR: learning rate given to the Adam optimizer of each soft Q network.
        POLICY_LR: learning rate given to the Adam optimizer of the policy network.
        ALPHA_LR: learning rate given to the Adam optimizer created for alpha.
    """
    self.replay_buffer = replay_buffer

    # Two online soft Q networks, one target network per online network,
    # and the policy network.
    self.soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim)
    self.soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim)
    self.target_soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim)
    self.target_soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim)
    self.policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim, action_range)

    # Set the mode of every network exactly once: online networks and the
    # policy in train mode, target networks in eval mode. (This block used
    # to be executed twice with nothing in between that could change it.)
    for trained_net in (self.soft_q_net1, self.soft_q_net2, self.policy_net):
        trained_net.train()
    for frozen_net in (self.target_soft_q_net1, self.target_soft_q_net2):
        frozen_net.eval()

    # log_alpha starts at 0, so alpha starts at exp(0) = 1.
    # NOTE(review): self.alpha is computed once here from the initial
    # log_alpha; presumably it is re-assigned wherever log_alpha is
    # optimized -- confirm against the update step.
    self.log_alpha = tf.Variable(0, dtype=np.float32, name='log_alpha')
    self.alpha = tf.math.exp(self.log_alpha)

    print('Soft Q Network (1,2): ', self.soft_q_net1)
    print('Policy Network: ', self.policy_net)

    # Initialize each target network from its online counterpart
    # (target_ini is described as a hard-copy update); the returned
    # network replaces the attribute.
    self.target_soft_q_net1 = self.target_ini(self.soft_q_net1, self.target_soft_q_net1)
    self.target_soft_q_net2 = self.target_ini(self.soft_q_net2, self.target_soft_q_net2)

    # One Adam optimizer per trainable component.
    self.soft_q_optimizer1 = tf.optimizers.Adam(SOFT_Q_LR)
    self.soft_q_optimizer2 = tf.optimizers.Adam(SOFT_Q_LR)
    self.policy_optimizer = tf.optimizers.Adam(POLICY_LR)
    self.alpha_optimizer = tf.optimizers.Adam(ALPHA_LR)
| 260 | |
| 261 | def target_ini(self, net, target_net): |
| 262 | """ hard-copy update for initializing target networks """ |
no test coverage detected