| 333 | self.target_soft_q_net2 = self.target_soft_update(self.soft_q_net2, self.target_soft_q_net2, soft_tau) |
| 334 | |
def save(self):  # save trained weights
    """Save the trained network weights and ``log_alpha`` to disk.

    Everything is written into the directory ``model/<ALG_NAME>_<ENV_ID>``,
    which is created if it does not exist yet: one ``.npz`` file per
    network (both soft Q-networks, both target soft Q-networks and the
    policy network) plus ``log_alpha.npy``.
    """
    path = os.path.join('model', '_'.join([ALG_NAME, ENV_ID]))
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test followed by os.makedirs().
    os.makedirs(path, exist_ok=True)

    # (network, file name) pairs, written in the same order as before.
    networks = (
        (self.soft_q_net1, 'model_q_net1.npz'),
        (self.soft_q_net2, 'model_q_net2.npz'),
        (self.target_soft_q_net1, 'model_target_q_net1.npz'),
        (self.target_soft_q_net2, 'model_target_q_net2.npz'),
        (self.policy_net, 'model_policy_net.npz'),
    )
    for net, filename in networks:
        tl.files.save_npz(net.trainable_weights, os.path.join(path, filename))

    # save log_alpha variable
    np.save(os.path.join(path, 'log_alpha.npy'), self.log_alpha.numpy())
| 346 | |
| 347 | def load_weights(self): # load trained weights |
| 348 | path = os.path.join('model', '_'.join([ALG_NAME, ENV_ID])) |