MCPcopy Index your code
hub / github.com/tensorlayer/TensorLayer / normal_save

Method normal_save

tests/models/test_model_save.py:88–145  ·  view source on GitHub ↗
(self, model_basic)

Source from the content-addressed store, hash-verified

86 pass
87
def normal_save(self, model_basic):
    """Round-trip ``model_basic``'s weights through each save format and check error paths.

    For every save/load combination, the second-to-last entry of
    ``model_basic.all_weights`` is overwritten, the weights are reloaded from
    disk, and the entry is compared against its value captured before saving.
    Afterwards, the expected exceptions for the ``ckpt`` format and for an
    unknown ``xyz`` format are asserted explicitly. (Previously, a missing
    exception went unnoticed.)

    Files are written relative to the current working directory.

    Args:
        model_basic: model under test. It must expose ``save_weights``,
            ``load_weights`` and ``all_weights`` whose items support
            ``.numpy()`` and ``.assign()`` (presumably a TensorLayer model;
            verify against the caller).
    """
    ori_val = model_basic.all_weights[-2].numpy()
    # Offset every element so the overwrite really differs from the saved
    # value. An all-zeros overwrite would make the checks vacuous if this
    # tensor happened to be zero-initialized.
    modify_val = ori_val + 1.0

    def _assert_restored_after_load(filepath, **load_kwargs):
        # Clobber the tracked tensor, reload `filepath`, and check that the saved value came back.
        model_basic.all_weights[-2].assign(modify_val)
        model_basic.load_weights(filepath, **load_kwargs)
        self.assertLess(np.max(np.abs(ori_val - model_basic.all_weights[-2].numpy())), 1e-7)

    # Default save: no explicit `format` argument (the result is not reloaded here).
    model_basic.save_weights('./model_basic.none')

    # hdf5
    print('testing hdf5 saving...')
    model_basic.save_weights("./model_basic.h5")
    _assert_restored_after_load("./model_basic.h5")
    _assert_restored_after_load("./model_basic.h5", format="hdf5")
    _assert_restored_after_load("./model_basic.h5", format="hdf5", in_order=False)

    # npz
    print('testing npz saving...')
    model_basic.save_weights("./model_basic.npz", format='npz')
    # The original test never asserted this inferred-format reload.
    _assert_restored_after_load("./model_basic.npz")
    _assert_restored_after_load("./model_basic.npz", format='npz')
    # Save again without an explicit format, and make sure the file still reloads.
    model_basic.save_weights("./model_basic.npz")
    _assert_restored_after_load("./model_basic.npz")

    # npz_dict
    print('testing npz_dict saving...')
    model_basic.save_weights("./model_basic.npz", format='npz_dict')
    _assert_restored_after_load("./model_basic.npz", format='npz_dict')

    # ckpt: saving in this format must raise NotImplementedError.
    with self.assertRaises(NotImplementedError):
        model_basic.save_weights('./model_basic.ckpt', format='ckpt')

    # Other cases: an unknown format must be rejected on save and on load.
    with self.assertRaises(ValueError):
        model_basic.save_weights('./model_basic.xyz', format='xyz')
    # The rejected save above never produced this file, so loading it must fail.
    with self.assertRaises(FileNotFoundError):
        model_basic.load_weights('./model_basic.xyz', format='xyz')
    # The file exists, but the format is unknown.
    with self.assertRaises(ValueError):
        model_basic.load_weights('./model_basic.h5', format='xyz')

Callers 1

test_normal_save · Method · 0.95

Calls 2

save_weights · Method · 0.80
load_weights · Method · 0.45

Tested by

no test coverage detected