github.com/tensorlayer/TensorLayer / load_and_assign_ckpt

Function load_and_assign_ckpt

tensorlayer/files/utils.py:2823–2859

Load weights by name from a given file of ckpt format into a TL model.

(model_dir, network=None, skip=True)


```python
def load_and_assign_ckpt(model_dir, network=None, skip=True):
    """Load weights by name from a given file of ckpt format.

    Parameters
    ----------
    model_dir : str
        Directory from which the weights will be loaded; it should contain a checkpoint in ckpt format.
        Examples: model_dir = /root/cnn_model/
    network : Model
        TL model.
    skip : bool
        If 'skip' == True, loaded weights whose name is not found in the network will be skipped. If 'skip' is False,
        an error will be raised when a mismatch is found. Default True.

    Returns
    -------
    None

    """
    # check_ckpt_file, assign_tf_variable, pywrap_tensorflow and logging are
    # module-level names in tensorlayer/files/utils.py.
    if network is None:
        # network defaults to None but is required: its weights are read below.
        raise ValueError("load_and_assign_ckpt requires a TL model as 'network'.")

    model_path, filename = check_ckpt_file(model_dir)

    reader = pywrap_tensorflow.NewCheckpointReader(model_path)
    var_to_shape_map = reader.get_variable_to_shape_map()

    net_weights_name = [w.name for w in network.all_weights]

    # Walk the checkpoint's variables and copy each one into the network
    # weight that carries the same name.
    for key in var_to_shape_map:
        if key not in net_weights_name:
            if skip:
                logging.warning("Weights named '%s' not found in network. Skip it." % key)
            else:
                raise RuntimeError(
                    "Weights named '%s' not found in network. Hint: set argument skip=True "
                    "if you want to skip redundant or mismatch weights." % key
                )
        else:
            assign_tf_variable(network.all_weights[net_weights_name.index(key)], reader.get_tensor(key))
    logging.info("[*] Model restored from ckpt %s" % filename)
```
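The name-matching step can be isolated from TensorFlow. The sketch below is a hypothetical helper, `match_weights_by_name`, that is not part of TensorLayer: it reproduces the lookup and the `skip` behavior on plain strings, and it builds a dict once instead of calling `list.index()` for every checkpoint key (the dict keeps the first index of a duplicated name, as `list.index()` does). Like the original loop, it only walks the checkpoint's keys, so a network weight with no counterpart in the checkpoint is silently left as it was.

```python
import logging
from typing import Dict, List, Sequence, Tuple

logger = logging.getLogger(__name__)


def match_weights_by_name(
    ckpt_keys: Sequence[str],
    net_weight_names: Sequence[str],
    skip: bool = True,
) -> Tuple[List[Tuple[str, int]], List[str]]:
    """Hypothetical helper: pair each checkpoint key with the index of the
    network weight that has the same name.

    Returns (matched, skipped), where matched holds (key, weight_index) pairs.
    Raises RuntimeError on the first unmatched key when skip is False.
    """
    # Build the name -> index map once; setdefault keeps the first index of a
    # duplicated name, which is what list.index() would return.
    index_by_name: Dict[str, int] = {}
    for i, name in enumerate(net_weight_names):
        index_by_name.setdefault(name, i)

    matched: List[Tuple[str, int]] = []
    skipped: List[str] = []
    for key in ckpt_keys:
        idx = index_by_name.get(key)
        if idx is not None:
            matched.append((key, idx))
        elif skip:
            logger.warning("Weights named '%s' not found in network. Skip it.", key)
            skipped.append(key)
        else:
            raise RuntimeError(
                "Weights named '%s' not found in network. Hint: set argument skip=True "
                "if you want to skip redundant or mismatch weights." % key
            )
    return matched, skipped


if __name__ == "__main__":
    ckpt = ["conv1/W", "conv1/b", "fc/W_extra"]
    net = ["conv1/W", "conv1/b", "fc/W"]
    pairs, extra = match_weights_by_name(ckpt, net, skip=True)
    print(pairs)  # [('conv1/W', 0), ('conv1/b', 1)]
    print(extra)  # ['fc/W_extra']
```

One design difference is deliberate: the sketch finishes the scan before anything would be assigned, so with `skip=False` a mismatch aborts before any weight changes. The original assigns inside the loop, so a `RuntimeError` raised on a later key can leave earlier weights already overwritten.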

Callers

nothing calls this directly

Calls (2)

check_ckpt_file · Function · 0.85
assign_tf_variable · Function · 0.85
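`check_ckpt_file` resolves `model_dir` to the `model_path` handed to `pywrap_tensorflow.NewCheckpointReader`, plus a `filename` used in the final log message; its body is not shown on this page. To illustrate the on-disk layout it has to deal with, here is a hypothetical `find_ckpt_prefix` helper, an assumption rather than TensorLayer's code. A TensorFlow checkpoint is a path prefix with a `<prefix>.index` file and one or more `<prefix>.data-NNNNN-of-MMMMM` shards beside it, and checkpoint readers are opened on the prefix rather than on either file.

```python
import os
from typing import Tuple


def find_ckpt_prefix(model_dir: str) -> Tuple[str, str]:
    """Hypothetical helper (not TensorLayer's check_ckpt_file): return
    (path_prefix, prefix_name) for the single checkpoint in model_dir."""
    names = os.listdir(model_dir)
    # Each checkpoint contributes exactly one '<prefix>.index' file.
    prefixes = sorted(n[: -len(".index")] for n in names if n.endswith(".index"))
    if len(prefixes) != 1:
        raise FileNotFoundError(
            "expected exactly one '<prefix>.index' file in %r, found %d"
            % (model_dir, len(prefixes))
        )
    prefix = prefixes[0]
    # The tensor values live in '<prefix>.data-NNNNN-of-MMMMM' shards.
    if not any(n.startswith(prefix + ".data-") for n in names):
        raise FileNotFoundError("no data shard found for checkpoint prefix %r" % prefix)
    # Readers are opened on the prefix path, not on a concrete file.
    return os.path.join(model_dir, prefix), prefix


if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as d:
        for suffix in (".index", ".data-00000-of-00001"):
            open(os.path.join(d, "model" + suffix), "wb").close()
        path, name = find_ckpt_prefix(d)
        print(name)  # model
```

This sketch insists on exactly one checkpoint per directory; a directory holding several saves would need an explicit rule for picking one.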

Tested by

no test coverage detected
