)

def load_weights(self, filepath, format=None, in_order=True, skip=False):
    """Load model weights from a file previously saved by self.save_weights().

    Parameters
    ----------
    filepath : str
        Filename from which the model weights will be loaded.
    format : str or None
        If not specified (None), the postfix of the filepath is used to decide its format. If specified,
        the value should be 'hdf5', 'npz', 'npz_dict' or 'ckpt'. Other formats are not currently supported.
        The format must match the one used when the file was saved with self.save_weights().
        Default is None.
    in_order : bool
        Load weights into the model sequentially or by name. Only takes effect when 'format' is 'hdf5'.
        If 'in_order' is True, weights from the file are loaded into the model sequentially.
        If 'in_order' is False, weights from the file are loaded into the model by matching their names
        with the names of the model weights. This is particularly useful when restoring a model in eager (graph)
        mode from a weights file saved in graph (eager) mode.
        Default is True.
    skip : bool
        Skip weights whose names do not match between the file and the model. Only takes effect when 'format' is
        'hdf5' or 'npz_dict'. If 'skip' is True, the 'in_order' argument is ignored and loaded weights
        whose names are not found in the model weights (self.all_weights) are skipped. If 'skip' is False,
        an error is raised when a mismatch is found.
        Default is False.

    Examples
    --------
    1) Load a model from an hdf5 file.
    >>> net = tl.models.vgg16()
    >>> net.load_weights('./model_graph.h5', in_order=False, skip=True)  # load weights by name, skipping mismatches
    >>> net.load_weights('./model_eager.h5')  # load sequentially

    2) Load a model from an npz file.
    >>> net.load_weights('./model.npz')

    3) Load a model from an npz file that was previously saved in npz_dict format.
    >>> net.load_weights('./model.npz', format='npz_dict')

    Notes
    -----
    1) 'in_order' only takes effect when 'format' is 'hdf5'. When loading a weights file that was
    saved in a different mode, it is recommended to set 'in_order' to False so weights are matched by name.
    2) 'skip' only takes effect when 'format' is 'hdf5' or 'npz_dict'. If 'skip' is True,
    the 'in_order' argument is ignored.

    """
    if not os.path.exists(filepath):
        raise FileNotFoundError("file {} doesn't exist.".format(filepath))

    if format is None:
        # infer the format from the file extension (text after the last '.')
        format = filepath.split('.')[-1]

    if format == 'hdf5' or format == 'h5':
        if skip or not in_order:
            # load by weight name
            utils.load_hdf5_to_weights(filepath, self, skip)
        else:
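When `format` is None, `load_weights` takes everything after the last '.' in `filepath` as the format, and it accepts 'h5' as well as 'hdf5'. One consequence is that a file saved as 'npz_dict' but named `model.npz` is inferred as plain 'npz', which is why the docstring example for npz_dict files passes `format='npz_dict'` explicitly. The sketch below isolates that inference step in a hypothetical helper, `resolve_format`. Normalizing 'h5' to 'hdf5' and raising `ValueError` for anything else are assumptions of this sketch; the listing is truncated and does not show how the method itself handles unsupported formats.

```python
def resolve_format(filepath, fmt=None):
    """Return the weights-file format that load_weights would dispatch on.

    Hypothetical helper: `fmt` plays the role of load_weights' `format` argument.
    """
    if fmt is None:
        # Same rule as load_weights: take everything after the last '.' in the path.
        fmt = filepath.split('.')[-1]
    if fmt in ('hdf5', 'h5'):
        # load_weights accepts both spellings; this sketch normalizes them to 'hdf5'.
        return 'hdf5'
    if fmt in ('npz', 'npz_dict', 'ckpt'):
        return fmt
    # Assumption: the real method's handling of unknown formats is not shown in the listing.
    raise ValueError("unsupported weights format: {!r}".format(fmt))


print(resolve_format('./model_eager.h5'))             # hdf5
print(resolve_format('./model.npz'))                  # npz, even if the file was saved as npz_dict
print(resolve_format('./model.npz', fmt='npz_dict'))  # npz_dict
```

Splitting on '.' also means a path with no extension, such as './model', resolves to '/model' (the text after the dot in './'). `os.path.splitext` would avoid that surprise, but the sketch deliberately keeps the method's own rule.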
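For hdf5 files, `load_weights` loads by name (through `utils.load_hdf5_to_weights`) whenever `skip` is True or `in_order` is False, and otherwise loads sequentially. The toy sketch below models weights as plain Python containers to show how the two strategies differ. `assign_weights` is a hypothetical name, not TensorLayer API. The length check in the sequential branch, the `KeyError` on unmatched names, and leaving unmatched model weights unchanged are assumptions made for illustration.

```python
def assign_weights(model_weights, file_weights, in_order=True, skip=False):
    """Toy model of the two hdf5 loading strategies (hypothetical, not TensorLayer API).

    model_weights: dict mapping name -> value, in model order.
    file_weights: list of (name, value) pairs, in the order they were saved.
    Returns a new dict; the input dict is not mutated.
    """
    result = dict(model_weights)
    if skip or not in_order:
        # By name: each saved weight goes to the model weight with the same name.
        for name, value in file_weights:
            if name in result:
                result[name] = value
            elif not skip:
                # Assumption: a mismatch is reported as KeyError in this sketch.
                raise KeyError("weight {!r} not found in model".format(name))
            # With skip=True, saved weights with no matching name are ignored.
    else:
        # Sequential: the i-th saved weight goes to the i-th model weight; names are ignored.
        if len(file_weights) != len(result):
            raise ValueError("file has {} weights, model has {}".format(len(file_weights), len(result)))
        for model_name, (_, value) in zip(list(result), file_weights):
            result[model_name] = value
    return result


model = {"conv1/W": 0, "conv1/b": 0}
saved = [("conv1/b", 2), ("conv1/W", 1)]  # same names, different order
print(assign_weights(model, saved))                  # {'conv1/W': 2, 'conv1/b': 1}
print(assign_weights(model, saved, in_order=False))  # {'conv1/W': 1, 'conv1/b': 2}
```

The sequential strategy only works if the file and the model list their weights in the same order, which is why matching by name is the safer choice when the saving and loading modes differ.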