Repository: github.com/tensorlayer/TensorLayer

Method load_weights

examples/reinforcement_learning/tutorial_SAC.py, lines 347–355
Signature: load_weights(self)


```python
def load_weights(self):  # load trained weights from disk
    # Checkpoint directory: model/<ALG_NAME>_<ENV_ID>
    path = os.path.join('model', '_'.join([ALG_NAME, ENV_ID]))
    extend_path = lambda s: os.path.join(path, s)
    # Restore both soft Q-networks and their target copies
    tl.files.load_and_assign_npz(extend_path('model_q_net1.npz'), self.soft_q_net1)
    tl.files.load_and_assign_npz(extend_path('model_q_net2.npz'), self.soft_q_net2)
    tl.files.load_and_assign_npz(extend_path('model_target_q_net1.npz'), self.target_soft_q_net1)
    tl.files.load_and_assign_npz(extend_path('model_target_q_net2.npz'), self.target_soft_q_net2)
    # Restore the policy network
    tl.files.load_and_assign_npz(extend_path('model_policy_net.npz'), self.policy_net)
    # Restore the entropy-temperature parameter log_alpha (written with np.save)
    self.log_alpha.assign(np.load(extend_path('log_alpha.npy')))
```
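load_weights() reads six files from one directory and fails on the first one that is missing. The sketch below rebuilds the same directory path and reports every absent file up front. The values 'SAC' and 'Pendulum-v0' for ALG_NAME and ENV_ID are placeholders assumed for illustration, and missing_weight_files() and require_checkpoint() are hypothetical helpers, not part of the tutorial.

```python
import os

# Placeholder values (assumption): the tutorial defines ALG_NAME and ENV_ID itself.
ALG_NAME = 'SAC'
ENV_ID = 'Pendulum-v0'

# File names read by load_weights(), in the order it reads them.
WEIGHT_FILES = [
    'model_q_net1.npz',
    'model_q_net2.npz',
    'model_target_q_net1.npz',
    'model_target_q_net2.npz',
    'model_policy_net.npz',
    'log_alpha.npy',
]


def checkpoint_dir(root='model', alg_name=ALG_NAME, env_id=ENV_ID):
    """Rebuild the directory path that load_weights() joins its file names onto."""
    return os.path.join(root, '_'.join([alg_name, env_id]))


def missing_weight_files(path):
    """Hypothetical helper: list the expected files that do not exist under path."""
    return [name for name in WEIGHT_FILES
            if not os.path.isfile(os.path.join(path, name))]


def require_checkpoint(path):
    """Hypothetical helper: raise one error naming every missing file."""
    missing = missing_weight_files(path)
    if missing:
        raise FileNotFoundError(f'missing SAC checkpoint files in {path}: {missing}')
```

Calling require_checkpoint(checkpoint_dir()) before load_weights() turns a partial checkpoint into a single, explicit error instead of a failure partway through restoring the networks.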

Callers (4)

tutorial_SAC.py · File · 0.45
tutorial_format.py · File · 0.45

Calls (1)

load · Method · 0.45

Tested by

No test coverage detected.
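A minimal pytest-style sketch of what a test for this load path could check, written with plain NumPy so it runs without TensorFlow or TensorLayer. It replaces the tf.Variable holding log_alpha with a NumPy scalar, and it uses NumPy's default arr_0, arr_1, … keys from np.savez, which is an assumption and not necessarily the layout tl.files.save_npz produces. The function names and the example values are hypothetical.

```python
import os
import tempfile

import numpy as np


def test_log_alpha_round_trip():
    """A scalar written with np.save comes back with the same shape, dtype and value."""
    log_alpha = np.float32(-0.5)  # illustrative value, not from the tutorial
    with tempfile.TemporaryDirectory() as path:
        target = os.path.join(path, 'log_alpha.npy')
        np.save(target, log_alpha)
        restored = np.load(target)
    assert restored.shape == ()  # 0-d array, suitable for assigning to a scalar variable
    assert restored.dtype == np.float32
    assert restored == log_alpha


def test_npz_round_trip():
    """Arrays saved positionally with np.savez are restored in the same order."""
    weights = [np.arange(6, dtype=np.float32).reshape(2, 3),
               np.zeros(3, dtype=np.float32)]
    with tempfile.TemporaryDirectory() as path:
        target = os.path.join(path, 'model_policy_net.npz')
        np.savez(target, *weights)
        with np.load(target) as data:  # close the archive before the directory is removed
            restored = [data[f'arr_{i}'] for i in range(len(weights))]
    for original, loaded in zip(weights, restored):
        np.testing.assert_array_equal(original, loaded)
```

Both tests write into a temporary directory, so they leave no files behind and do not depend on an existing model/ checkpoint.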