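Load layer weights from an HDF5 group by layer name. Parameters ---------- f: hdf5 group An HDF5 group created by h5py.File() or create_group(). layers: list A list of layers to load weights into. skip : boolean If 'skip' == True, a stored layer whose name is not found in 'layers' is skipped; if 'skip' is False, an error is raised when such a mismatch is found. Default False.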
(f, layers, skip=False)
| 2624 | |
| 2625 | |
def _decode_hdf5_str_list(values):
    """Return the items of an hdf5 string attribute as a list of ``str``.

    Items may arrive as ``bytes`` (including ``numpy.bytes_``) or already as
    ``str`` (h5py 3.x reads variable-length string attributes as ``str``);
    bytes are decoded as UTF-8 and str items are passed through unchanged.
    """
    return [v.decode('utf8') if isinstance(v, bytes) else v for v in values]


def _load_weights_from_hdf5_group(f, layers, skip=False):
    """
    Load layer weights from an hdf5 group by layer name.

    Parameters
    ----------
    f: hdf5 group
        An hdf5 group created by h5py.File() or create_group().
    layers: list
        A list of layers to load weights into.
    skip : boolean
        If 'skip' == True, a stored layer whose name is not found in 'layers' is skipped
        with a warning. If 'skip' is False, an error is raised when such a mismatch is
        found. Default False.

    Raises
    ------
    RuntimeError
        If a layer name stored in ``f.attrs["layer_names"]`` is not found in ``layers``
        and ``skip`` is False.
    Exception
        If a matched entry is not a TensorLayer Model, ModelLayer, LayerList or Layer.

    """
    layer_names = _decode_hdf5_str_list(f.attrs["layer_names"])
    layer_index = {layer.name: layer for layer in layers}

    for name in layer_names:
        if name not in layer_index:
            if skip:
                logging.warning("Layer named '%s' not found in network. Skip it.", name)
                continue
            raise RuntimeError(
                "Layer named '%s' not found in network. Hint: set argument skip=True "
                "if you want to skip redundant or mismatch Layers." % name
            )

        g = f[name]
        layer = layer_index[name]
        # Container types: their sub-layers' weights are read from the nested group `g`.
        if isinstance(layer, tl.models.Model):
            _load_weights_from_hdf5_group(g, layer.all_layers, skip)
        elif isinstance(layer, tl.layers.ModelLayer):
            _load_weights_from_hdf5_group(g, layer.model.all_layers, skip)
        elif isinstance(layer, tl.layers.LayerList):
            _load_weights_from_hdf5_group(g, layer.layers, skip)
        elif isinstance(layer, tl.layers.Layer):
            weight_names = _decode_hdf5_str_list(g.attrs['weight_names'])
            # Stored weights are assigned positionally to layer.all_weights.
            for iid, w_name in enumerate(weight_names):
                value = np.asarray(g[w_name])
                # FIXME : this is only for compatibility -- BatchNorm weights stored
                # with more than one dimension are squeezed before assignment
                # (presumably files from an older save format; confirm).
                if isinstance(layer, tl.layers.BatchNorm) and value.ndim > 1:
                    value = value.squeeze()
                assign_tf_variable(layer.all_weights[iid], value)
        else:
            raise Exception("Only layer or model can be loaded from hdf5.")
| 2673 | |
| 2674 | def save_weights_to_hdf5(filepath, network): |
no test coverage detected
searching dependent graphs…