MCPcopy
hub / github.com/tensorlayer/TensorLayer / restore_model

Function restore_model

tensorlayer/models/vgg.py:157–181  ·  view source on GitHub ↗
(model, layer_type)

Source from the content-addressed store, hash-verified

155
156
def restore_model(model, layer_type):
    """Download (if needed) and load pre-trained VGG weights into ``model``.

    The weight file registered for ``layer_type`` is fetched into the local
    ``models`` directory via ``maybe_download_and_extract`` and read with
    NumPy. Arrays are collected in sorted-key order until as many arrays have
    been gathered as ``model.all_weights`` holds. They are then passed to
    ``assign_weights``.

    Parameters
    ----------
    model : Model
        VGG model whose weights are assigned by ``assign_weights``. Its
        ``all_weights`` length decides how many arrays are read.
    layer_type : str
        ``'vgg16'`` or ``'vgg19'``. Used as the key into ``model_saved_name``
        and ``model_urls``.

    Raises
    ------
    KeyError
        If ``layer_type`` is not a key of ``model_saved_name`` / ``model_urls``.
    """
    logging.info("Restore pre-trained weights")
    # download weights
    saved_name = model_saved_name[layer_type]
    maybe_download_and_extract(saved_name, 'models', model_urls[layer_type])
    weights_path = os.path.join('models', saved_name)
    n_needed = len(model.all_weights)
    weights = []
    if layer_type == 'vgg16':
        # The NpzFile returned for an .npz archive keeps its file open; the
        # context manager closes it. Indexing by sorted key (instead of
        # materialising sorted(npz.items())) loads arrays lazily, so the early
        # break really skips arrays the model does not need.
        with np.load(weights_path, allow_pickle=True) as npz:
            for key in sorted(npz.files):
                arr = npz[key]
                logging.info(" Loading weights %s in %s", arr.shape, key)
                weights.append(arr)
                if len(weights) == n_needed:
                    break
    elif layer_type == 'vgg19':
        # This file is a pickled dict wrapped in a 0-d object array; .item()
        # unwraps it. encoding='latin1' is needed to unpickle data written by
        # Python 2.
        params_by_layer = np.load(weights_path, allow_pickle=True, encoding='latin1').item()
        for key in sorted(params_by_layer):
            params = params_by_layer[key]
            # NOTE(review): each value appears to be a [W, b] pair (indices 0 and 1
            # are logged); confirm against the weight file.
            logging.info(" Loading %s in %s", params[0].shape, key)
            logging.info(" Loading %s in %s", params[1].shape, key)
            weights.extend(params)
            if len(weights) == n_needed:
                break
    # assign weight values
    assign_weights(weights, model)
182
183
184def VGG_static(layer_type, batch_norm=False, end_with='outputs', name=None):

Callers 2

vgg16 (Function) · 0.85
vgg19 (Function) · 0.85

Calls 3

assign_weights (Function) · 0.90
load (Method) · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…