github.com/tensorlayer/TensorLayer

Function restore_params

tensorlayer/models/resnet.py:179–203
(network, path='models')

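```python
import os

# Module-level imports assumed from the rest of tensorlayer/models/resnet.py
from tensorlayer import logging
from tensorlayer.files import assign_weights, maybe_download_and_extract


def restore_params(network, path='models'):
    """Load pre-trained ResNet50 weights (fchollet/deep-learning-models) into `network`, matching layers by name."""
    logging.info("Restore pre-trained parameters")
    filename = 'resnet50_weights_tf_dim_ordering_tf_kernels.h5'
    # Download the HDF5 weight file into `path` unless it is already there.
    maybe_download_and_extract(
        filename,
        path,
        'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/',
    )
    try:
        import h5py
    except ImportError as e:
        raise ImportError('h5py is required to read the ResNet50 weight file') from e

    # The context manager closes the file even if an assignment fails.
    with h5py.File(os.path.join(path, filename), 'r') as f:
        for layer in network.all_layers:
            # Skip layers without parameters (pooling, activation, ...).
            if len(layer.all_weights) == 0:
                continue
            # Each weighted layer has an HDF5 group with the same name;
            # read every dataset in that group into a NumPy array.
            w_names = list(f[layer.name])
            params = [f[layer.name][n][:] for n in w_names]
            # if 'bn' in layer.name:
            #     params = [x.reshape(1, 1, 1, -1) for x in params]
            assign_weights(params, layer)
            del params
```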
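The loop in `restore_params` matches every layer that owns weights to an HDF5 group of the same name and hands the stored arrays to `assign_weights`, which assigns them positionally. The sketch below models that control flow in plain Python so it runs without TensorLayer or h5py. It is illustrative only: the weight file is a hypothetical nested dict (group name → dataset name → values), `FakeLayer` and `restore_by_name` are made-up names, and member order is emulated with `sorted()` on the assumption that h5py lists group members in alphanumeric order by default. Unlike the original, it collects layers that have no matching group instead of raising `KeyError`, and it checks the array count before assigning.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FakeLayer:
    """Hypothetical stand-in for a TensorLayer layer: a name plus a weight list."""
    name: str
    all_weights: List[object] = field(default_factory=list)


def restore_by_name(layers: List[FakeLayer],
                    weight_file: Dict[str, Dict[str, object]]) -> List[str]:
    """Copy arrays from a nested-dict 'weight file' into layers with matching names.

    Returns the names of weighted layers that had no group in the file.
    """
    missing = []
    for layer in layers:
        if len(layer.all_weights) == 0:
            continue  # parameter-free layers are skipped, as in the original
        group = weight_file.get(layer.name)
        if group is None:
            missing.append(layer.name)  # the original would raise KeyError here
            continue
        # Assumption: emulates h5py's default alphanumeric member order.
        params = [group[n] for n in sorted(group)]
        if len(params) != len(layer.all_weights):
            raise ValueError(
                f"{layer.name}: file has {len(params)} arrays, "
                f"layer expects {len(layer.all_weights)}")
        # Positional assignment, analogous to assign_weights(params, layer).
        for i, value in enumerate(params):
            layer.all_weights[i] = value
    return missing


# Usage: one matched layer, one parameter-free layer, one missing layer.
layers = [
    FakeLayer("conv1", [None, None]),
    FakeLayer("pool1"),
    FakeLayer("fc1000", [None, None]),
]
weights = {"conv1": {"conv1_b": [0.1], "conv1_W": [[1.0]]}}
missing = restore_by_name(layers, weights)
print(missing)                 # ['fc1000']
print(layers[0].all_weights)   # [[[1.0]], [0.1]]
```

Because uppercase letters sort before lowercase ones, `conv1_W` (kernel) lands before `conv1_b` (bias). Correct behavior therefore depends on the stored dataset names sorting in the same order the layer declares its weights, which is worth verifying for any other weight file.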

Callers 1

ResNet50 · Function · 0.70

Calls 3

assign_weights · Function · 0.90
close · Method · 0.80

Tested by

no test coverage detected
