github.com/tensorlayer/TensorLayer

Method save_weights

tensorlayer/models/core.py:838–897

Saves the model weights to filepath in the given format. Use self.load_weights() to restore them. Parameters: filepath (str), the file name to which the model weights will be saved; format (str or None), the saved file format.

(self, filepath, format=None)
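When format is left at its default of None, the method infers the format from the file-name suffix. The standalone sketch below mirrors that branch of the source so the rule can be checked without building a model. The helper name resolve_save_format and its fmt parameter are hypothetical; they are not part of TensorLayer.

```python
def resolve_save_format(filepath, fmt=None):
    """Hypothetical helper mirroring the format=None branch of save_weights.

    Returns the format string that save_weights would dispatch on.
    """
    if fmt is None:
        # Everything after the last '.' is treated as the suffix (case-sensitive).
        postfix = filepath.split('.')[-1]
        if postfix in ['h5', 'hdf5', 'npz', 'ckpt']:
            fmt = postfix
        else:
            # Unknown or missing suffix: fall back to HDF5.
            fmt = 'hdf5'
    # An explicitly passed format is returned unchanged.
    return fmt


if __name__ == "__main__":
    for path in ["./model.h5", "./model.npz", "./model", "./model.NPZ", "./model.npz_dict"]:
        print(path, "->", resolve_save_format(path))
    print(resolve_save_format("./model.npz", fmt="npz_dict"))
```

Two consequences follow from this logic. A suffix of 'npz_dict' is not recognised, so that format must be requested with format='npz_dict'. Suffix matching is also case-sensitive, so './model.NPZ' is saved as HDF5. A suffix of 'h5' is passed through unchanged, which is harmless because the dispatch treats 'h5' and 'hdf5' identically.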


def save_weights(self, filepath, format=None):
    """Input filepath, save model weights into a file of given format.
    Use self.load_weights() to restore.

    Parameters
    ----------
    filepath : str
        Filename to which the model weights will be saved.
    format : str or None
        Saved file format.
        Value should be None, 'hdf5', 'npz', 'npz_dict' or 'ckpt'. Other formats are not supported.
        1) If this is set to None, the suffix of filepath is used to decide the saved format.
           If the suffix is not in ['h5', 'hdf5', 'npz', 'ckpt'], the file is saved in hdf5 format by default.
        2) 'hdf5' saves the model weight names in a list, and each layer has its weights stored in a group of
           the hdf5 file.
        3) 'npz' saves model weights sequentially into an npz file.
        4) 'npz_dict' saves model weights along with their names as a dict into an npz file.
        5) 'ckpt' is intended to save model weights into a TensorFlow ckpt file, but currently raises
           NotImplementedError.

        Default None.

    Examples
    --------
    1) Save model weights in hdf5 format by default.
    >>> net = tl.models.vgg16()
    >>> net.save_weights('./model.h5')
    ...
    >>> net.load_weights('./model.h5')

    2) Save model weights in npz/npz_dict format
    >>> net = tl.models.vgg16()
    >>> net.save_weights('./model.npz')
    >>> net.save_weights('./model.npz', format='npz_dict')

    """
    # Nothing to save if the model has no weights or has not been built yet.
    if self.all_weights is None or len(self.all_weights) == 0:
        logging.warning("Model contains no weights or layers haven't been built, nothing will be saved")
        return

    # Infer the format from the file suffix when none is given.
    if format is None:
        postfix = filepath.split('.')[-1]
        if postfix in ['h5', 'hdf5', 'npz', 'ckpt']:
            format = postfix
        else:
            format = 'hdf5'

    # Dispatch to the format-specific writer.
    if format == 'hdf5' or format == 'h5':
        utils.save_weights_to_hdf5(filepath, self)
    elif format == 'npz':
        utils.save_npz(self.all_weights, filepath)
    elif format == 'npz_dict':
        utils.save_npz_dict(self.all_weights, filepath)
    elif format == 'ckpt':
        # TODO: enable this when tf save ckpt is enabled
        raise NotImplementedError("ckpt load/save is not supported now.")
    else:
        raise ValueError(
            "Save format must be 'hdf5', 'npz', 'npz_dict' or 'ckpt'."
        )
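The docstring distinguishes 'npz' (weights stored sequentially) from 'npz_dict' (weights stored with their names). The NumPy sketch below illustrates those two layouts with plain numpy.savez; it is not TensorLayer's utils.save_npz or utils.save_npz_dict, whose internals are not shown here, and the weight names are hypothetical.

```python
import io

import numpy as np

# Hypothetical stand-ins for a model's weights: (name, array) pairs.
weights = [
    ("dense1_W", np.ones((4, 3), dtype=np.float32)),
    ("dense1_b", np.zeros(3, dtype=np.float32)),
]

# 'npz'-style layout: arrays stored by position only; NumPy names them arr_0, arr_1, ...
seq_buf = io.BytesIO()
np.savez(seq_buf, *[arr for _, arr in weights])
seq_buf.seek(0)
with np.load(seq_buf) as data:
    seq_keys = sorted(data.files)

# 'npz_dict'-style layout: each array stored under its weight name.
dict_buf = io.BytesIO()
np.savez(dict_buf, **{name: arr for name, arr in weights})
dict_buf.seek(0)
with np.load(dict_buf) as data:
    dict_keys = sorted(data.files)
    restored_b = data["dense1_b"]  # loaded into memory before the file closes

print(seq_keys)   # ['arr_0', 'arr_1']
print(dict_keys)  # ['dense1_W', 'dense1_b']
```

With the positional layout, restoring relies on the saved order matching the model's weight order; the named layout lets a loader match arrays to weights by name instead.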

Callers: 15 (9 listed)

test_exceptions (method) · 0.95
test_layerlist (method) · 0.95
test_exceptions (method) · 0.95
main_word2vec_basic (function) · 0.95
test_keras_save.py (file) · 0.80
normal_save (method) · 0.80
test_normal_save (method) · 0.80
test_skip (method) · 0.80
test_nested_vgg (method) · 0.80

Calls

logging.warning, utils.save_weights_to_hdf5, utils.save_npz, utils.save_npz_dict

Tested by: 9 (7 listed)

test_exceptions (method) · 0.76
test_layerlist (method) · 0.76
test_exceptions (method) · 0.76
normal_save (method) · 0.64
test_normal_save (method) · 0.64
test_skip (method) · 0.64
test_nested_vgg (method) · 0.64