Repository: github.com/tensorlayer/TensorLayer

Method load

Defined in examples/reinforcement_learning/tutorial_TD3.py, lines 328–336.
Signature: load(self)


        tl.files.save_npz(self.target_policy_net.trainable_weights, extend_path('model_target_policy_net.npz'))

    def load(self):  # load trained weights
        # Relies on the tutorial's module-level imports of os and tensorlayer (as tl).
        # Reads from model/<ALG_NAME>_<ENV_ID>/, matching the file names written by save().
        path = os.path.join('model', '_'.join([ALG_NAME, ENV_ID]))
        extend_path = lambda s: os.path.join(path, s)
        # Twin Q-networks (critics) and their target copies
        tl.files.load_and_assign_npz(extend_path('model_q_net1.npz'), self.q_net1)
        tl.files.load_and_assign_npz(extend_path('model_q_net2.npz'), self.q_net2)
        tl.files.load_and_assign_npz(extend_path('model_target_q_net1.npz'), self.target_q_net1)
        tl.files.load_and_assign_npz(extend_path('model_target_q_net2.npz'), self.target_q_net2)
        # Policy network (actor) and its target copy
        tl.files.load_and_assign_npz(extend_path('model_policy_net.npz'), self.policy_net)
        tl.files.load_and_assign_npz(extend_path('model_target_policy_net.npz'), self.target_policy_net)


if __name__ == '__main__':
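The method delegates serialization to TensorLayer's `tl.files.load_and_assign_npz`, reading one `.npz` file per network and assigning the stored weights into networks that already exist on the agent. The NumPy-only sketch below illustrates that pattern; it is not TensorLayer's implementation. `ToyNet`, `checkpoint_dir`, `save_all`, `load_all`, and the `ALG_NAME`/`ENV_ID` values are hypothetical, and the missing-file and shape checks are choices of this sketch, not taken from the tutorial.

```python
import os
import tempfile

import numpy as np

# Hypothetical values; the tutorial defines its own ALG_NAME and ENV_ID.
ALG_NAME = 'TD3'
ENV_ID = 'Pendulum-v0'

# One checkpoint file per network, named as in load(): model_<name>.npz
NET_NAMES = ['q_net1', 'q_net2', 'target_q_net1', 'target_q_net2',
             'policy_net', 'target_policy_net']


class ToyNet:
    """Hypothetical stand-in for a network: an ordered list of weight arrays."""

    def __init__(self, shapes, seed):
        rng = np.random.default_rng(seed)
        self.weights = [rng.standard_normal(s).astype(np.float32) for s in shapes]


def checkpoint_dir(root):
    # Same layout as load(): <root>/model/<ALG_NAME>_<ENV_ID>
    return os.path.join(root, 'model', '_'.join([ALG_NAME, ENV_ID]))


def save_all(nets, root):
    path = checkpoint_dir(root)
    os.makedirs(path, exist_ok=True)
    for name in NET_NAMES:
        # Positional arrays are stored under the keys arr_0, arr_1, ...
        np.savez(os.path.join(path, f'model_{name}.npz'), *nets[name].weights)


def load_all(nets, root):
    path = checkpoint_dir(root)
    for name in NET_NAMES:
        fname = os.path.join(path, f'model_{name}.npz')
        if not os.path.exists(fname):
            # Fail fast rather than continue with untrained weights.
            raise FileNotFoundError(fname)
        with np.load(fname) as data:
            # Rebuild the list in saved order (arr_0, arr_1, ...).
            arrays = [data[f'arr_{i}'] for i in range(len(data.files))]
        net = nets[name]
        if [a.shape for a in arrays] != [w.shape for w in net.weights]:
            raise ValueError(f'shape mismatch for {name}')
        # Copy into the existing arrays so the pre-built network is updated in place.
        for w, a in zip(net.weights, arrays):
            w[...] = a


if __name__ == '__main__':
    shapes = [(3, 64), (64,), (64, 1), (1,)]
    with tempfile.TemporaryDirectory() as root:
        trained = {n: ToyNet(shapes, seed=i) for i, n in enumerate(NET_NAMES)}
        save_all(trained, root)
        fresh = {n: ToyNet(shapes, seed=100 + i) for i, n in enumerate(NET_NAMES)}
        load_all(fresh, root)
        print(all(np.array_equal(a, b)
                  for n in NET_NAMES
                  for a, b in zip(trained[n].weights, fresh[n].weights)))  # True
```

Because the weights are copied into already-constructed networks, each network must be built with the same layer shapes as the one that was saved, just as `load` assigns into `self.q_net1`, `self.policy_net`, and the other existing networks.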

Callers (1)

tutorial_TD3.py · file scope · 0.45

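Calls

No outgoing calls resolved by the indexer, although the method body calls os.path.join and tl.files.load_and_assign_npz (six times, once per network).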

Tested by

no test coverage detected