| 326 | tl.files.save_npz(self.target_policy_net.trainable_weights, extend_path('model_target_policy_net.npz')) |
| 327 | |
| 328 | def load(self): # load trained weights: restore q_net1, q_net2, target_q_net1, target_q_net2, policy_net and target_policy_net from model_<net>.npz files under model/<ALG_NAME>_<ENV_ID>/ (relative to the current working directory) |
| 329 | path = os.path.join('model', '_'.join([ALG_NAME, ENV_ID])) |
| 330 | extend_path = lambda s: os.path.join(path, s) |
| 331 | tl.files.load_and_assign_npz(extend_path('model_q_net1.npz'), self.q_net1) |
| 332 | tl.files.load_and_assign_npz(extend_path('model_q_net2.npz'), self.q_net2) |
| 333 | tl.files.load_and_assign_npz(extend_path('model_target_q_net1.npz'), self.target_q_net1) |
| 334 | tl.files.load_and_assign_npz(extend_path('model_target_q_net2.npz'), self.target_q_net2) |
| 335 | tl.files.load_and_assign_npz(extend_path('model_policy_net.npz'), self.policy_net) |
| 336 | tl.files.load_and_assign_npz(extend_path('model_target_policy_net.npz'), self.target_policy_net) |
| 337 | |
| 338 | |
| 339 | if __name__ == '__main__': |