Load weights by name from a given file of ckpt format Parameters ---------- model_dir : str Filename to which the weights will be loaded, should be of ckpt format. Examples: model_dir = /root/cnn_model/ network : Model TL model. skip : bool If 'sk
(model_dir, network=None, skip=True)
| 2821 | |
| 2822 | |
def load_and_assign_ckpt(model_dir, network=None, skip=True):
    """Load weights by name from a checkpoint (ckpt) and assign them to a network.

    Every variable stored in the checkpoint is matched by name against
    ``network.all_weights``; each matching network weight is overwritten with
    the checkpoint value via ``assign_tf_variable``.

    Parameters
    ----------
    model_dir : str
        Path passed to ``check_ckpt_file`` to locate the checkpoint to load
        from. Example: ``model_dir = '/root/cnn_model/'``.
    network : Model
        TL model whose weights will be assigned. Required, despite the
        ``None`` default kept for backward compatibility.
    skip : bool
        If True (the default), checkpoint variables whose name is not found
        among the network's weights are logged with a warning and skipped.
        If False, a ``RuntimeError`` is raised on the first such mismatch.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``network`` is None.
    RuntimeError
        If ``skip`` is False and a checkpoint variable has no network weight
        with the same name.

    Notes
    -----
    Only the checkpoint -> network direction is checked: network weights with
    no counterpart in the checkpoint are left untouched and are not reported.
    """
    if network is None:
        raise ValueError("`network` must be provided to assign checkpoint weights to.")

    model_path, filename = check_ckpt_file(model_dir)

    reader = pywrap_tensorflow.NewCheckpointReader(model_path)
    var_to_shape_map = reader.get_variable_to_shape_map()

    # Build the name -> weight lookup once so each checkpoint key costs an
    # O(1) dict lookup instead of a linear list.index() scan. setdefault keeps
    # the FIRST weight for a duplicated name, matching list.index() semantics.
    name_to_weight = {}
    for weight in network.all_weights:
        name_to_weight.setdefault(weight.name, weight)

    for key in var_to_shape_map:
        if key not in name_to_weight:
            if skip:
                logging.warning("Weights named '%s' not found in network. Skip it." % key)
                continue
            raise RuntimeError(
                "Weights named '%s' not found in network. Hint: set argument skip=True "
                "if you want to skip redundant or mismatch weights." % key
            )
        assign_tf_variable(name_to_weight[key], reader.get_tensor(key))
    logging.info("[*] Model restored from ckpt %s" % filename)
| 2860 | |
| 2861 | |
| 2862 | def ckpt_to_npz_dict(model_dir, save_name='model.npz', rename_key=False): |
nothing calls this directly
no test coverage detected
searching dependent graphs…