Load weights by name from a given file of hdf5 format Parameters ---------- filepath : str Filename to which the weights will be loaded, should be of hdf5 format. network : Model TL model. skip : bool If 'skip' == True, loaded weights whose name is not fo
(filepath, network, skip=False)
| 2734 | |
| 2735 | |
def load_hdf5_to_weights(filepath, network, skip=False):
    """Load weights by name from a given file of hdf5 format.

    Parameters
    ----------
    filepath : str
        Filename from which the weights will be loaded, should be of hdf5 format.
    network : Model
        TL model.
    skip : bool
        Forwarded to ``_load_weights_from_hdf5_group``. If 'skip' == True, loaded weights whose name is not
        found in the network will be skipped. If 'skip' is False, error will be raised when mismatch is found.
        Default False.

    Returns
    -------
    None
        The weights are loaded into the layers of ``network``.

    Raises
    ------
    NameError
        If the 'layer_names' attribute of the hdf5 file cannot be read.

    """
    # The context manager closes the file on every exit path, including the
    # NameError below and any error raised while the weights are being loaded.
    with h5py.File(filepath, 'r') as f:
        try:
            # Names may come back as bytes or as str depending on how they were
            # stored; only bytes need decoding.
            layer_names = [n.decode('utf8') if isinstance(n, bytes) else n for n in f.attrs["layer_names"]]
        except Exception:
            raise NameError(
                "The loaded hdf5 file needs to have 'layer_names' as attributes. "
                "Please check whether this hdf5 file is saved from TL."
            )

        all_layers = network.all_layers
        net_index = {layer.name: layer for layer in all_layers}

        if len(all_layers) != len(layer_names):
            logging.warning(
                "Number of weights mismatch. "
                "Trying to load a saved file with " + str(len(layer_names)) + " layers into a model with " +
                str(len(all_layers)) + " layers."
            )

        # check mismatch from network weights to hdf5
        saved_names = set(layer_names)
        for name in net_index:
            if name not in saved_names:
                logging.warning("Network layer named '%s' not found in loaded hdf5 file. It will be skipped." % name)

        # load weights from hdf5 to network
        _load_weights_from_hdf5_group(f, all_layers, skip)

    logging.info("[*] Load %s SUCCESS!" % filepath)
| 2781 | |
| 2782 | |
| 2783 | def check_ckpt_file(model_dir): |
nothing calls this directly
no test coverage detected
searching dependent graphs…