Restore a TL model architecture from an HDF5 file. Supports loading model weights. Parameters ----------- filepath : str The name of model file. load_weights : bool Whether to load model weights. Returns -------- network : TensorLayer Model. Examples
(filepath='model.hdf5', load_weights=False)
| 297 | |
| 298 | |
def load_hdf5_graph(filepath='model.hdf5', load_weights=False):
    """Restore a TL model architecture from an HDF5 file, optionally loading its weights.

    Parameters
    -----------
    filepath : str
        The name of model file.
    load_weights : bool
        Whether to load model weights.

    Returns
    --------
    network : TensorLayer Model.

    Raises
    ------
    RuntimeError
        If ``load_weights`` is True but the file does not contain weights
        (no ``layer_names`` attribute).

    Examples
    --------
    - see ``tl.files.save_hdf5_graph``
    """
    logging.info("[*] Loading TL model from {}, loading weights={}".format(filepath, load_weights))

    # Read everything we need from the file inside a context manager so the
    # handle is always closed, even if a later step raises. It is also closed
    # before ``M.load_weights`` reopens the same path.
    with h5py.File(filepath, 'r') as f:
        model_config_str = f.attrs["model_config"]
        has_weights = 'layer_names' in f.attrs

    # h5py < 3.0 returns string attributes as bytes, h5py >= 3.0 returns str.
    if isinstance(model_config_str, bytes):
        model_config_str = model_config_str.decode('utf8')

    # SECURITY NOTE(review): eval() executes arbitrary Python embedded in the
    # file. Only load model files from trusted sources. It is kept (rather than
    # ast.literal_eval) because the config is stored as a Python repr that may
    # not be a pure literal -- confirm against save_hdf5_graph before replacing.
    model_config = eval(model_config_str)

    # Version metadata only drives warnings, so a file without it should still load.
    version_info = model_config.get("version_info") or {}
    backend_version = version_info.get("backend_version")
    tensorlayer_version = version_info.get("tensorlayer_version")
    if backend_version is not None and backend_version != tf.__version__:
        logging.warning(
            "Saved model uses tensorflow version {}, but now you are using tensorflow version {}".format(
                backend_version, tf.__version__
            )
        )
    if tensorlayer_version is not None and tensorlayer_version != tl.__version__:
        logging.warning(
            "Saved model uses tensorlayer version {}, but now you are using tensorlayer version {}".format(
                tensorlayer_version, tl.__version__
            )
        )

    M = static_graph2net(model_config)
    if load_weights:
        if not has_weights:
            raise RuntimeError("Saved model does not contain weights.")
        M.load_weights(filepath=filepath)

    logging.info("[*] Loaded TL model from {}, loading weights={}".format(filepath, load_weights))

    return M
| 353 | |
| 354 | |
| 355 | # def load_pkl_graph(name='model.pkl'): |
nothing calls this directly
no test coverage detected
searching dependent graphs…