MCPcopy Index your code
hub / github.com/tensorlayer/TensorLayer / test_skip

Method test_skip

tests/models/test_model_save.py:160–189  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

158 print(e)
159
def test_skip(self):
    """Check skip-aware weight loading for both the dynamic and static model pairs.

    For each pair, weights saved from the reference model (``dynamic_basic`` /
    ``static_basic``) are loaded into its ``*_skip`` counterpart:

    * ``skip=True`` must succeed and restore the tested weight;
    * ``in_order=False, skip=False`` must raise.
    """
    print('-' * 20, 'test skip save/load', '-' * 20)

    print("testing dynamic skip load...")
    self._check_skip_load(self.dynamic_basic, self.dynamic_basic_skip)

    print("testing static skip load...")
    self._check_skip_load(self.static_basic, self.static_basic_skip)


def _check_skip_load(self, src_model, skip_model):
    """Save ``src_model``'s weights and verify how ``skip_model`` loads them.

    Steps:
    1. Save ``src_model``'s weights to ``./model_basic.h5``.
    2. Overwrite ``skip_model.all_weights[1]`` with zeros.
    3. Reload with ``skip=True``.
    4. Assert the weight again matches its pre-zeroing value (max abs diff < 1e-7).
    5. Assert that a by-name load with ``skip=False`` raises.

    Args:
        src_model: model whose weights are written to the file.
        skip_model: model the weights are loaded back into.
    """
    path = "./model_basic.h5"
    src_model.save_weights(path)

    before = skip_model.all_weights[1].numpy()
    skip_model.all_weights[1].assign(np.zeros_like(before))
    skip_model.load_weights(path, skip=True)
    after = skip_model.all_weights[1].numpy()
    # NOTE(review): this assumes weight index 1 holds the same value in
    # src_model and skip_model before the load (e.g. a deterministic init);
    # confirm against the fixture setup.
    self.assertLess(np.max(np.abs(before - after)), 1e-7)

    # The non-skip load must fail. Asserting it (instead of catching and
    # printing) is what lets this negative case actually fail when broken.
    # NOTE(review): presumably skip_model lacks a layer present in the saved
    # file; the concrete exception type is not visible here, so narrow
    # ``Exception`` once confirmed.
    with self.assertRaises(Exception) as ctx:
        skip_model.load_weights(path, in_order=False, skip=False)
    print(ctx.exception)
190
191 def test_nested_vgg(self):
192 print('-' * 20, 'test nested vgg', '-' * 20)

Callers

nothing calls this directly

Calls 2

save_weights · Method · 0.80
load_weights · Method · 0.45

Tested by

no test coverage detected