MCPcopy
hub / github.com/tensorlayer/TensorLayer / save

Method save

examples/reinforcement_learning/tutorial_SAC.py:335–345  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

333 self.target_soft_q_net2 = self.target_soft_update(self.soft_q_net2, self.target_soft_q_net2, soft_tau)
334
def save(self):  # save trained weights
    """Save the trained network weights and ``log_alpha`` to disk.

    Files are written into ``model/<ALG_NAME>_<ENV_ID>/`` (a path relative
    to the current working directory). The directory is created if needed:

    - ``model_q_net1.npz`` / ``model_q_net2.npz``: trainable weights of
      ``soft_q_net1`` / ``soft_q_net2``.
    - ``model_target_q_net1.npz`` / ``model_target_q_net2.npz``: trainable
      weights of ``target_soft_q_net1`` / ``target_soft_q_net2``.
    - ``model_policy_net.npz``: trainable weights of ``policy_net``.
    - ``log_alpha.npy``: ``self.log_alpha`` converted via ``.numpy()``.
    """
    path = os.path.join('model', '_'.join([ALG_NAME, ENV_ID]))
    # exist_ok=True avoids the race in "if not exists: makedirs", where a
    # concurrently created directory would make makedirs raise FileExistsError.
    os.makedirs(path, exist_ok=True)

    # (file name, network) pairs, saved in the same order as before.
    nets_by_file = (
        ('model_q_net1.npz', self.soft_q_net1),
        ('model_q_net2.npz', self.soft_q_net2),
        ('model_target_q_net1.npz', self.target_soft_q_net1),
        ('model_target_q_net2.npz', self.target_soft_q_net2),
        ('model_policy_net.npz', self.policy_net),
    )
    for filename, net in nets_by_file:
        tl.files.save_npz(net.trainable_weights, os.path.join(path, filename))

    np.save(os.path.join(path, 'log_alpha.npy'), self.log_alpha.numpy())  # save log_alpha variable
346
347 def load_weights(self): # load trained weights
348 path = os.path.join('model', '_'.join([ALG_NAME, ENV_ID]))

Callers 1

tutorial_SAC.py · File · 0.45

Calls

no outgoing calls

Tested by

no test coverage detected