github.com/tensorlayer/TensorLayer

Function load_hdf5_graph

tensorlayer/files/utils.py:299–352

Restore a TL model architecture from an HDF5 file, optionally loading the model weights.

Parameters: filepath (str), the name of the model file; load_weights (bool), whether to load the model weights.
Returns: network, a TensorLayer Model.
Examples: see tl.files.save_hdf5_graph.

(filepath='model.hdf5', load_weights=False)
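When load_weights is true, the loader first checks that the file's root attributes contain a layer_names entry and raises RuntimeError("Saved model does not contain weights.") if it is missing. Only after that check does it call the rebuilt model's load_weights(filepath=filepath). The sketch below reproduces just that guard in plain Python. It is an illustration, not TensorLayer code, and it assumes a plain dict in place of the h5py attribute set and a hypothetical restore_weights callback in place of M.load_weights.

```python
from typing import Callable, Mapping


def maybe_load_weights(
    attrs: Mapping[str, object],
    load_weights: bool,
    restore_weights: Callable[[], None],
) -> bool:
    """Mirror the load_weights guard in load_hdf5_graph (illustrative sketch).

    `attrs` stands in for the HDF5 root attributes (h5py `f.attrs`);
    `restore_weights` is a hypothetical stand-in for `M.load_weights(filepath=...)`.
    Returns True if weights were restored, False if only the architecture is kept.
    """
    if not load_weights:
        # Architecture only: the caller did not ask for weights.
        return False
    if "layer_names" not in attrs:
        # Same failure mode as the source: the file holds a graph but no weights.
        raise RuntimeError("Saved model does not contain weights.")
    restore_weights()
    return True
```

Because the check runs before any weights are touched, a file saved without weights fails loudly instead of silently returning a randomly initialized model.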


```python
# Relies on names available at module level in tensorlayer/files/utils.py:
# h5py, tf (TensorFlow), tl (TensorLayer), logging and static_graph2net.
def load_hdf5_graph(filepath='model.hdf5', load_weights=False):
    """Restore TL model architecture from an HDF5 file. Support loading model weights.

    Parameters
    -----------
    filepath : str
        The name of model file.
    load_weights : bool
        Whether to load model weights.

    Returns
    --------
    network : TensorLayer Model.

    Examples
    --------
    - see ``tl.files.save_hdf5_graph``
    """
    logging.info("[*] Loading TL model from {}, loading weights={}".format(filepath, load_weights))

    f = h5py.File(filepath, 'r')

    # The architecture is stored as a UTF-8 string attribute holding a Python dict.
    # eval() executes arbitrary expressions, so only load files from trusted sources.
    model_config_str = f.attrs["model_config"].decode('utf8')
    model_config = eval(model_config_str)

    # version_info_str = f.attrs["version_info"].decode('utf8')
    # version_info = eval(version_info_str)
    # Warn (but keep loading) if the saving environment differs from the current one.
    version_info = model_config["version_info"]
    backend_version = version_info["backend_version"]
    tensorlayer_version = version_info["tensorlayer_version"]
    if backend_version != tf.__version__:
        logging.warning(
            "Saved model uses tensorflow version {}, but now you are using tensorflow version {}".format(
                backend_version, tf.__version__
            )
        )
    if tensorlayer_version != tl.__version__:
        logging.warning(
            "Saved model uses tensorlayer version {}, but now you are using tensorlayer version {}".format(
                tensorlayer_version, tl.__version__
            )
        )

    # Rebuild the model from its static-graph description.
    M = static_graph2net(model_config)
    if load_weights:
        # A `layer_names` root attribute is taken as the marker that weights were saved.
        if not ('layer_names' in f.attrs.keys()):
            raise RuntimeError("Saved model does not contain weights.")
        M.load_weights(filepath=filepath)

    # Note: f is not closed if an exception is raised above.
    f.close()

    logging.info("[*] Loaded TL model from {}, loading weights={}".format(filepath, load_weights))

    return M
```
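The architecture comes back from a UTF-8 string attribute, model_config, that holds a Python dict. The loader decodes it and passes it to eval. It then compares the embedded version_info against the running TensorFlow and TensorLayer versions, warning on a mismatch without aborting. The stdlib-only sketch below covers that step under the assumption that the stored string contains only plain literals (dicts, lists, tuples, strings, numbers, booleans, None). It deliberately swaps in ast.literal_eval for eval, which rejects arbitrary expressions, so a crafted file cannot run code at load time. Configs that rely on non-literal expressions would fail to parse with this variant. The function name parse_model_config and its current_versions argument are hypothetical.

```python
import ast
import logging

logger = logging.getLogger(__name__)


def parse_model_config(raw: bytes, current_versions: dict) -> dict:
    """Decode and parse a stored model_config attribute (illustrative sketch).

    `raw` plays the role of f.attrs["model_config"]; `current_versions` maps
    version_info keys (e.g. "backend_version", "tensorlayer_version") to the
    versions installed now. Both names are assumptions for this sketch.
    """
    # literal_eval accepts only Python literals, unlike the eval() in the source.
    model_config = ast.literal_eval(raw.decode("utf8"))
    version_info = model_config["version_info"]
    for key, current in current_versions.items():
        saved = version_info[key]
        if saved != current:
            # Mismatches are reported but do not stop loading, as in the source.
            logger.warning(
                "Saved model uses %s %s, but now you are using %s", key, saved, current
            )
    return model_config
```

For a config made only of literals, parse_model_config(str(cfg).encode("utf8"), versions) returns a dict equal to cfg. An input such as b"__import__('os')" raises ValueError instead of being executed.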

Callers

nothing calls this directly

Calls (3)

static_graph2net · function · 0.85
close · method · 0.80
load_weights · method · 0.45

Tested by

no test coverage detected
