Create a :class:`SessionInit` to be loaded into a session automatically from any supported object, using some smart heuristics. The object can be: + A TF checkpoint + A dict of numpy arrays + An npz file, to be interpreted as a dict + An empty string or None, in which case the sessinit will be a no-op + A list of supported objects, to be initialized one by one
(obj, *, ignore_mismatch=False)
| 265 | |
| 266 | |
def SmartInit(obj, *, ignore_mismatch=False):
    """
    Create a :class:`SessionInit` to be loaded to a session,
    automatically from any supported objects, with some smart heuristics.
    The object can be:

    + A TF checkpoint (given as a path prefix of existing checkpoint files)
    + A dict of numpy arrays
    + A npz file, to be interpreted as a dict
    + A npy file containing a pickled dict of numpy arrays
    + An empty string or None, in which case the sessinit will be a no-op
    + A list of supported objects, to be initialized one by one

    String paths have ``~`` expanded before they are inspected.

    Args:
        obj: a supported object
        ignore_mismatch (bool): ignore failures when the value and the
            variable do not match in their shapes.
            If False, it will throw exception on such errors.
            If True, it will only print a warning.

    Returns:
        SessionInit:

    Raises:
        ValueError: if ``obj`` is a string that is neither a .npy/.npz file
            nor the prefix of any existing file, or if ``obj`` has an
            unsupported type.
    """
    if not obj:
        return JustCurrentSession()
    if isinstance(obj, list):
        return ChainInit([SmartInit(x, ignore_mismatch=ignore_mismatch) for x in obj])
    if isinstance(obj, six.string_types):
        obj = os.path.expanduser(obj)
        if obj.endswith((".npy", ".npz")):
            assert tf.gfile.Exists(obj), "File {} does not exist!".format(obj)
            filename = obj
            logger.info("Loading dictionary from {} ...".format(filename))
            if filename.endswith('.npy'):
                # The .npy file stores a pickled dict (hence ``.item()``), so
                # ``allow_pickle=True`` is required: NumPy >= 1.16.3 defaults
                # it to False and would refuse to load the file.
                # NOTE(review): unpickling can execute arbitrary code; only
                # load .npy files from trusted sources.
                obj = np.load(filename, encoding='latin1', allow_pickle=True).item()
            else:
                # Use the NpzFile as a context manager so its underlying file
                # handle is closed once all arrays are copied into the dict.
                with np.load(filename) as npz:
                    obj = dict(npz)
        elif len(tf.gfile.Glob(obj + "*")):
            # Assume to be a TF checkpoint.
            # A TF checkpoint must be a prefix of an actual file.
            return (SaverRestoreRelaxed if ignore_mismatch else SaverRestore)(obj)
        else:
            raise ValueError("Invalid argument to SmartInit: " + obj)

    if isinstance(obj, dict):
        return DictRestore(obj, ignore_mismatch=ignore_mismatch)
    # Format the type rather than concatenating it: ``str + type`` would
    # raise TypeError and mask the intended ValueError.
    raise ValueError("Invalid argument to SmartInit: {}".format(type(obj)))


get_model_loader = SmartInit
no test coverage detected