MCPcopy Index your code
hub / github.com/tensorlayer/TensorLayer / _load_weights_from_hdf5_group

Function _load_weights_from_hdf5_group

tensorlayer/files/utils.py:2626–2671  ·  view source on GitHub ↗

Load layer weights from an HDF5 group by layer name. Parameters: f (HDF5 group) — a group created by h5py.File() or create_group(); layers (list) — the layers to load weights into; skip (boolean) — if True, stored layers whose names are not found in `layers` are skipped; if False, an error is raised when a mismatch is found. Default False.

(f, layers, skip=False)

Source from the content-addressed store, hash-verified

2624
2625
def _load_weights_from_hdf5_group(f, layers, skip=False):
    """
    Load layer weights from a hdf5 group by layer name.

    Each name listed in ``f.attrs["layer_names"]`` is matched against the
    ``name`` attribute of the entries in ``layers``. Containers
    (``tl.models.Model``, ``tl.layers.ModelLayer`` and ``tl.layers.LayerList``)
    are loaded recursively from the sub-group stored under the same name.
    Plain ``tl.layers.Layer`` objects get their weights assigned, in order,
    from the datasets listed in the sub-group's ``attrs['weight_names']``.

    Parameters
    ----------
    f: hdf5 group
        A hdf5 group created by h5py.File() or create_group().
    layers: list
        A list of layers to load weights.
    skip : boolean
        If 'skip' == True, loaded layer whose name is not found in 'layers' will be skipped. If 'skip' is False,
        error will be raised when mismatch is found. Default False.

    Raises
    ------
    RuntimeError
        If a stored layer name has no counterpart in ``layers`` and ``skip`` is False.
    Exception
        If a matched object is not a Model, ModelLayer, LayerList or Layer.
    """

    def _to_str(raw):
        # Stored names can be bytes (fixed-length strings, the format written
        # with ``.encode('utf8')``) or already str (variable-length string
        # attributes); accept both instead of crashing on ``str.decode``.
        return raw.decode('utf8') if isinstance(raw, bytes) else str(raw)

    layer_names = [_to_str(n) for n in f.attrs["layer_names"]]
    layer_index = {layer.name: layer for layer in layers}

    for name in layer_names:
        if name not in layer_index:
            if not skip:
                raise RuntimeError(
                    "Layer named '%s' not found in network. Hint: set argument skip=True "
                    "if you want to skip redundant or mismatch Layers." % name
                )
            logging.warning("Layer named '%s' not found in network. Skip it." % name)
            continue

        g = f[name]
        layer = layer_index[name]
        if isinstance(layer, tl.models.Model):
            _load_weights_from_hdf5_group(g, layer.all_layers, skip)
        elif isinstance(layer, tl.layers.ModelLayer):
            _load_weights_from_hdf5_group(g, layer.model.all_layers, skip)
        elif isinstance(layer, tl.layers.LayerList):
            _load_weights_from_hdf5_group(g, layer.layers, skip)
        elif isinstance(layer, tl.layers.Layer):
            weight_names = [_to_str(n) for n in g.attrs['weight_names']]
            for iid, w_name in enumerate(weight_names):
                # Read each dataset from the file exactly once.
                value = np.asarray(g[w_name])
                # FIXME : this is only for compatibility
                # NOTE(review): presumably older files stored BatchNorm
                # parameters with extra singleton dims; squeeze them away.
                if isinstance(layer, tl.layers.BatchNorm) and value.ndim > 1:
                    value = value.squeeze()
                assign_tf_variable(layer.all_weights[iid], value)
        else:
            raise Exception("Only layer or model can be loaded from hdf5.")
2672
2673
2674def save_weights_to_hdf5(filepath, network):

Callers 1

load_hdf5_to_weights (Function) · 0.85

Calls 1

assign_tf_variable (Function) · 0.85

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…