Save the model weights to `filepath` in the given format; use self.load_weights() to restore them. Parameters: filepath (str) — filename to which the model weights will be saved; format (str or None) — saved file format.
(self, filepath, format=None)
| 836 | return M |
| 837 | |
| 838 | def save_weights(self, filepath, format=None): |
| 839 | """Input filepath, save model weights into a file of given format. |
| 840 | Use self.load_weights() to restore. |
| 841 | |
| 842 | Parameters |
| 843 | ---------- |
| 844 | filepath : str |
| 845 | Filename to which the model weights will be saved. |
| 846 | format : str or None |
| 847 | Saved file format. |
| 848 | Value should be None, 'hdf5', 'npz', 'npz_dict' or 'ckpt'. Other format is not supported now. |
| 849 | 1) If this is set to None, then the postfix of filepath will be used to decide saved format. |
| 850 | If the postfix is not in ['h5', 'hdf5', 'npz', 'ckpt'], then file will be saved in hdf5 format by default. |
| 851 | 2) 'hdf5' will save model weights name in a list and each layer has its weights stored in a group of |
| 852 | the hdf5 file. |
| 853 | 3) 'npz' will save model weights sequentially into a npz file. |
| 854 | 4) 'npz_dict' will save model weights along with its name as a dict into a npz file. |
| 855 | 5) 'ckpt' will save model weights into a tensorflow ckpt file. |
| 856 | |
| 857 | Default None. |
| 858 | |
| 859 | Examples |
| 860 | -------- |
| 861 | 1) Save model weights in hdf5 format by default. |
| 862 | >>> net = tl.models.vgg16() |
| 863 | >>> net.save_weights('./model.h5') |
| 864 | ... |
| 865 | >>> net.load_weights('./model.h5') |
| 866 | |
| 867 | 2) Save model weights in npz/npz_dict format |
| 868 | >>> net = tl.models.vgg16() |
| 869 | >>> net.save_weights('./model.npz') |
| 870 | >>> net.save_weights('./model.npz', format='npz_dict') |
| 871 | |
| 872 | """ |
| 873 | if self.all_weights is None or len(self.all_weights) == 0: |
| 874 | logging.warning("Model contains no weights or layers haven't been built, nothing will be saved") |
| 875 | return |
| 876 | |
| 877 | if format is None: |
| 878 | postfix = filepath.split('.')[-1] |
| 879 | if postfix in ['h5', 'hdf5', 'npz', 'ckpt']: |
| 880 | format = postfix |
| 881 | else: |
| 882 | format = 'hdf5' |
| 883 | |
| 884 | if format == 'hdf5' or format == 'h5': |
| 885 | utils.save_weights_to_hdf5(filepath, self) |
| 886 | elif format == 'npz': |
| 887 | utils.save_npz(self.all_weights, filepath) |
| 888 | elif format == 'npz_dict': |
| 889 | utils.save_npz_dict(self.all_weights, filepath) |
| 890 | elif format == 'ckpt': |
| 891 | # TODO: enable this when tf save ckpt is enabled |
| 892 | raise NotImplementedError("ckpt load/save is not supported now.") |
| 893 | else: |
| 894 | raise ValueError( |
| 895 | "Save format must be 'hdf5', 'npz', 'npz_dict' or 'ckpt'." |
no outgoing calls