| 314 | self.target_policy_net = self.target_soft_update(self.policy_net, self.target_policy_net, soft_tau) |
| 315 | |
def save(self):  # save trained weights
    """Save the trainable weights of every network to ``.npz`` files.

    Files are written under ``model/<ALG_NAME>_<ENV_ID>/``. The directory,
    including any missing parents, is created if it does not already exist.
    Filenames are unchanged from the previous implementation, so existing
    checkpoints remain compatible.
    """
    path = os.path.join('model', '_'.join([ALG_NAME, ENV_ID]))
    # exist_ok=True avoids the race in "if not exists: makedirs". With that
    # pattern, makedirs raises FileExistsError if the directory is created
    # between the check and the call.
    os.makedirs(path, exist_ok=True)
    # (network, filename) pairs, saved in the same order as before.
    checkpoints = (
        (self.q_net1, 'model_q_net1.npz'),
        (self.q_net2, 'model_q_net2.npz'),
        (self.target_q_net1, 'model_target_q_net1.npz'),
        (self.target_q_net2, 'model_target_q_net2.npz'),
        (self.policy_net, 'model_policy_net.npz'),
        (self.target_policy_net, 'model_target_policy_net.npz'),
    )
    for net, filename in checkpoints:
        tl.files.save_npz(net.trainable_weights, os.path.join(path, filename))
| 327 | |
| 328 | def load(self): # load trained weights |
| 329 | path = os.path.join('model', '_'.join([ALG_NAME, ENV_ID])) |