(self)
| 345 | np.save(extend_path('log_alpha.npy'), self.log_alpha.numpy()) # save log_alpha variable |
| 346 | |
def load_weights(self):
    """Restore the networks and the log_alpha variable from disk.

    Everything is read from the directory ``model/<ALG_NAME>_<ENV_ID>``:
    one ``.npz`` file per network, loaded with
    ``tl.files.load_and_assign_npz``, followed by ``log_alpha.npy``, whose
    contents are assigned to ``self.log_alpha``.
    """
    weights_dir = os.path.join('model', '_'.join([ALG_NAME, ENV_ID]))

    # Each entry pairs an on-disk file name with the network that receives
    # its parameters; the networks are loaded in the listed order.
    checkpoints = (
        ('model_q_net1.npz', self.soft_q_net1),
        ('model_q_net2.npz', self.soft_q_net2),
        ('model_target_q_net1.npz', self.target_soft_q_net1),
        ('model_target_q_net2.npz', self.target_soft_q_net2),
        ('model_policy_net.npz', self.policy_net),
    )
    for file_name, network in checkpoints:
        tl.files.load_and_assign_npz(os.path.join(weights_dir, file_name), network)

    # log_alpha is stored as a plain NumPy array rather than an .npz archive.
    saved_log_alpha = np.load(os.path.join(weights_dir, 'log_alpha.npy'))
    self.log_alpha.assign(saved_log_alpha)
| 356 | |
| 357 | |
| 358 | if __name__ == '__main__': |
no test coverage detected