github.com/tensorlayer/TensorLayer

Method load_weights

tensorlayer/models/core.py:899–970

Load model weights from a file previously saved by self.save_weights().

(self, filepath, format=None, in_order=True, skip=False)
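The defaults in this signature interact in a specific way: with format=None the format is taken from the file extension, and for hdf5 files skip=True overrides in_order. The sketch below restates that decision logic from the source listing as a standalone function so the combinations can be checked without TensorLayer installed. resolve_load_plan is a hypothetical name, the existence check on filepath is omitted, and non-hdf5 formats return None for the strategy because the listing is cut off before those branches.

```python
# Hypothetical helper: a pure-Python restatement of the argument handling visible
# in load_weights. It does not touch the filesystem and does not call TensorLayer.


def resolve_load_plan(filepath, format=None, in_order=True, skip=False):
    """Return (format, hdf5_strategy) as the visible part of load_weights decides them.

    hdf5_strategy is None for non-hdf5 formats, whose branches are not shown.
    """
    if format is None:
        # Same rule as the listing: use the text after the last '.' in the path.
        format = filepath.split('.')[-1]

    if format in ('hdf5', 'h5'):
        # skip=True (or in_order=False) selects name-based loading.
        strategy = 'by_name' if (skip or not in_order) else 'in_order'
        return format, strategy
    return format, None


# Usage, mirroring the docstring examples:
print(resolve_load_plan('./model_graph.h5', in_order=False, skip=True))  # ('h5', 'by_name')
print(resolve_load_plan('./model_eager.h5'))                             # ('h5', 'in_order')
print(resolve_load_plan('./model.npz', format='npz_dict'))               # ('npz_dict', None)
```

Because the format is whatever follows the last '.', a path with no extension such as './weights' yields '/weights' as the format, so passing format explicitly is the safer choice for unusual file names. Note also that an 'npz_dict' file saved with a '.npz' extension is inferred as 'npz', which is why the third docstring example passes format='npz_dict'.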


    def load_weights(self, filepath, format=None, in_order=True, skip=False):
        """Load model weights from a file previously saved by self.save_weights().

        Parameters
        ----------
        filepath : str
            Path of the file from which the model weights will be loaded.
        format : str or None
            If None, the format is inferred from the extension of filepath; the 'h5' extension is
            treated as 'hdf5'. If specified, the value should be 'hdf5', 'npz', 'npz_dict' or 'ckpt';
            other formats are not currently supported.
            The format must match the one used when the file was saved with self.save_weights().
            Default is None.
        in_order : bool
            Whether to load weights into the model sequentially or by name. Only used when 'format' is 'hdf5'.
            If True, weights from the file are loaded into the model in sequential order.
            If False, weights from the file are loaded by matching their names against the model's weights,
            which is particularly useful when restoring a model in eager (graph) mode from a weights file
            saved in graph (eager) mode.
            Default is True.
        skip : bool
            Whether to skip weights whose names do not match between the file and the model. Only used when
            'format' is 'hdf5' or 'npz_dict'. If True, 'in_order' is ignored, and any weight in the file whose
            name is not found in the model's weights (self.all_weights) is skipped. If False, an error is
            raised when a mismatch is found.
            Default is False.

        Examples
        --------
        1) Load weights from an hdf5 file.
        >>> net = tl.models.vgg16()
        >>> net.load_weights('./model_graph.h5', in_order=False, skip=True) # load weights by name, skipping mismatches
        >>> net.load_weights('./model_eager.h5') # load sequentially

        2) Load weights from an npz file.
        >>> net.load_weights('./model.npz')

        3) Load weights from an npz file that was previously saved in 'npz_dict' format.
        >>> net.load_weights('./model.npz', format='npz_dict')

        Notes
        -----
        1) 'in_order' is only used when 'format' is 'hdf5'. When loading a weights file that was saved
           in a different mode, it is recommended to set 'in_order' to False so that weights are matched by name.
        2) 'skip' is only used when 'format' is 'hdf5' or 'npz_dict'. If 'skip' is True,
           the 'in_order' argument is ignored.

        """
        if not os.path.exists(filepath):
            raise FileNotFoundError("file {} doesn't exist.".format(filepath))

        if format is None:
            format = filepath.split('.')[-1]

        if format == 'hdf5' or format == 'h5':
            if skip or not in_order:
                # load by weights name
                utils.load_hdf5_to_weights(filepath, self, skip)
            else:
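
The docstring's description of in_order and skip amounts to two assignment strategies: by position or by name, with skip deciding what happens to file weights the model does not have. The toy sketch below illustrates those semantics with plain Python values standing in for tensors. All function names are hypothetical and this is not TensorLayer's implementation; it assumes the cross-mode case means the same weight names stored in a different order, and the count check in load_in_order is an assumption of the toy rather than behavior shown in the listing.

```python
# Toy model of the two loading strategies described in the docstring.
# Hypothetical helpers, not TensorLayer code: weights are (name, value) pairs
# and plain Python values stand in for tensors.


def load_in_order(model_weights, file_weights):
    """Assign file values to model weights by position, ignoring names."""
    # Assumption of this toy: a count mismatch is an error.
    if len(file_weights) != len(model_weights):
        raise ValueError("file has %d weights, model has %d"
                         % (len(file_weights), len(model_weights)))
    names = [name for name, _ in model_weights]
    values = [value for _, value in file_weights]
    return dict(zip(names, values))


def load_by_name(model_weights, file_weights, skip=False):
    """Assign file values to the model weights whose names match."""
    model_names = {name for name, _ in model_weights}
    loaded = {}
    for name, value in file_weights:
        if name in model_names:
            loaded[name] = value
        elif not skip:
            # skip=False: a file weight with no counterpart in the model is an error.
            raise KeyError("weight %r from file not found in model" % name)
        # skip=True: the unmatched file weight is ignored.
    return loaded


def load_toy(model_weights, file_weights, in_order=True, skip=False):
    """Pick a strategy with the same condition as the hdf5 branch in the listing."""
    if skip or not in_order:
        return load_by_name(model_weights, file_weights, skip)
    return load_in_order(model_weights, file_weights)


# Assumed cross-mode case: same names, different storage order.
model = [('conv1/kernel', 0), ('conv1/bias', 0)]
saved = [('conv1/bias', 'b'), ('conv1/kernel', 'k')]

print(load_toy(model, saved))                  # {'conv1/kernel': 'b', 'conv1/bias': 'k'} (mismatched)
print(load_toy(model, saved, in_order=False))  # {'conv1/bias': 'b', 'conv1/kernel': 'k'}

# A file with an extra weight: loadable by name only when skip=True.
extra = saved + [('fc/kernel', 'f')]
print(load_toy(model, extra, skip=True))       # {'conv1/bias': 'b', 'conv1/kernel': 'k'}
```

Positional loading pairs the wrong values without complaint when only the order differs, which is the situation the docstring addresses by recommending name-based loading for files saved in another mode.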

Callers 9

test_layerlist (method) · 0.95
main_word2vec_basic (function) · 0.95
load_hdf5_graph (function) · 0.45
normal_save (method) · 0.45
test_normal_save (method) · 0.45
test_skip (method) · 0.45
test_nested_vgg (method) · 0.45

Calls

no outgoing calls

Tested by 7

test_layerlist (method) · 0.76
normal_save (method) · 0.36
test_normal_save (method) · 0.36
test_skip (method) · 0.36
test_nested_vgg (method) · 0.36