Repository: github.com/tensorlayer/TensorLayer

Function load_hdf5_to_weights

tensorlayer/files/utils.py, lines 2736–2780

Load weights by layer name from an HDF5 file into a TensorLayer (TL) model.

(filepath, network, skip=False)
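The `skip` parameter decides what happens when a name stored in the file has no matching layer in the model: it is either skipped or treated as an error. The sketch below models that contract with plain dictionaries in place of an HDF5 file and a TL `Model`. `load_by_name`, the dict layout, and the use of `ValueError` are hypothetical illustrations of the documented behavior, not the body of TensorLayer's `_load_weights_from_hdf5_group`, which this page does not show.

```python
from typing import Dict, List


def load_by_name(saved: Dict[str, List[float]],
                 model: Dict[str, List[float]],
                 skip: bool = False) -> List[str]:
    """Copy saved values into `model` wherever layer names match.

    Hypothetical stand-in: `saved` plays the role of the HDF5 file and
    `model` the role of the network's layers. Returns the saved names
    that were skipped because the model has no layer with that name.
    """
    skipped = []
    for name, values in saved.items():
        if name not in model:
            if skip:
                # skip=True: ignore saved weights with no matching layer
                skipped.append(name)
                continue
            # skip=False: a name mismatch is an error
            raise ValueError("saved layer %r not found in model" % name)
        model[name] = list(values)
    return skipped


model = {"dense1": [0.0, 0.0], "dense2": [0.0]}
saved = {"dense1": [0.5, -0.5], "extra": [1.0]}

print(load_by_name(saved, model, skip=True))  # ['extra']
print(model)  # {'dense1': [0.5, -0.5], 'dense2': [0.0]}
```

With `skip=False` the same call raises on `'extra'`. Model layers that are absent from the file (`dense2` here) are left untouched, in line with the "It will be skipped" warning the function logs for them.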


```python
# Excerpt from tensorlayer/files/utils.py. Relies on module-level names not shown
# here: h5py, TensorLayer's `logging` module, and the private helper
# `_load_weights_from_hdf5_group` defined in the same file.


def load_hdf5_to_weights(filepath, network, skip=False):
    """Load weights by name from a given file of hdf5 format

    Parameters
    ----------
    filepath : str
        Filename from which the weights will be loaded, should be of hdf5 format.
    network : Model
        TL model.
    skip : bool
        If 'skip' == True, loaded weights whose name is not found in 'weights' will be skipped. If 'skip' is False,
        error will be raised when mismatch is found. Default False.

    Returns
    -------
    None

    """
    # The context manager closes the file even if a check or the load raises
    with h5py.File(filepath, 'r') as f:
        try:
            layer_names = [n.decode('utf8') for n in f.attrs["layer_names"]]
        except Exception:
            raise NameError(
                "The loaded hdf5 file needs to have 'layer_names' as attributes. "
                "Please check whether this hdf5 file is saved from TL."
            )

        net_index = {layer.name: layer for layer in network.all_layers}

        if len(network.all_layers) != len(layer_names):
            logging.warning(
                "Number of weights mismatch. "
                "Trying to load a saved file with " + str(len(layer_names)) + " layers into a model with " +
                str(len(network.all_layers)) + " layers."
            )

        # check mismatch from network weights to hdf5
        for name in net_index.keys():
            if name not in layer_names:
                logging.warning("Network layer named '%s' not found in loaded hdf5 file. It will be skipped." % name)

        # load weights from hdf5 to network
        _load_weights_from_hdf5_group(f, network.all_layers, skip)

    logging.info("[*] Load %s SUCCESS!" % filepath)
```
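Before delegating to `_load_weights_from_hdf5_group`, the function only inspects names: it decodes the `layer_names` attribute, turns a missing attribute into a `NameError`, and logs warnings for a layer-count mismatch and for model layers absent from the file. The sketch below reproduces those checks with a plain dict standing in for h5py's `f.attrs` and a list of strings standing in for the model's layer names. `check_layer_names` is a hypothetical helper, and collecting messages in a list replaces the `logging.warning` calls.

```python
from typing import List, Mapping, Sequence


def check_layer_names(attrs: Mapping[str, Sequence[bytes]],
                      model_layer_names: Sequence[str]) -> List[str]:
    """Return the warnings the pre-load checks would emit.

    `attrs` stands in for the HDF5 file's attributes; `model_layer_names`
    stands in for the names of `network.all_layers`.
    """
    try:
        # The original decodes each entry, so names are assumed to be UTF-8 bytes
        saved_names = [n.decode("utf8") for n in attrs["layer_names"]]
    except Exception:
        raise NameError("The loaded hdf5 file needs to have 'layer_names' as attributes.")

    warnings = []
    if len(model_layer_names) != len(saved_names):
        warnings.append("Trying to load a saved file with %d layers into a model with %d layers."
                        % (len(saved_names), len(model_layer_names)))
    # Only the model-to-file direction is checked here
    for name in model_layer_names:
        if name not in saved_names:
            warnings.append("Network layer named '%s' not found in loaded hdf5 file." % name)
    return warnings


attrs = {"layer_names": [b"input", b"dense1"]}
for msg in check_layer_names(attrs, ["input", "dense1", "dense2"]):
    print(msg)
# Trying to load a saved file with 2 layers into a model with 3 layers.
# Network layer named 'dense2' not found in loaded hdf5 file.
```

These checks never fail the load on their own; names present in the file but missing from the model are not reported at this stage and, per the docstring, are handled by the `skip` flag during loading.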

Callers

nothing calls this directly

Calls 2

close (method) · 0.80

Tested by

no test coverage detected
