
Function SmartInit

tensorpack/tfutils/sessinit.py:267–312

Create a :class:`SessionInit` to be loaded to a session, automatically from any supported object, with some smart heuristics. The object can be:

+ A TF checkpoint
+ A dict of numpy arrays
+ An npz file, to be interpreted as a dict
+ An empty string or None, in which case the sessinit will be a no-op
+ A list of supported objects, to be initialized one by one

(obj, *, ignore_mismatch=False)


def SmartInit(obj, *, ignore_mismatch=False):
    """
    Create a :class:`SessionInit` to be loaded to a session,
    automatically from any supported objects, with some smart heuristics.
    The object can be:

    + A TF checkpoint
    + A dict of numpy arrays
    + An npz file, to be interpreted as a dict
    + An empty string or None, in which case the sessinit will be a no-op
    + A list of supported objects, to be initialized one by one

    Args:
        obj: a supported object
        ignore_mismatch (bool): ignore failures when the value and the
            variable do not match in their shapes.
            If False, it will throw an exception on such errors.
            If True, it will only print a warning.

    Returns:
        SessionInit:
    """
    # Any falsy value (None, "", empty list, empty dict) means "load nothing".
    if not obj:
        return JustCurrentSession()
    # A list is handled recursively and applied in order.
    if isinstance(obj, list):
        return ChainInit([SmartInit(x, ignore_mismatch=ignore_mismatch) for x in obj])
    if isinstance(obj, six.string_types):
        obj = os.path.expanduser(obj)
        if obj.endswith(".npy") or obj.endswith(".npz"):
            assert tf.gfile.Exists(obj), "File {} does not exist!".format(obj)
            filename = obj
            logger.info("Loading dictionary from {} ...".format(filename))
            if filename.endswith('.npy'):
                obj = np.load(filename, encoding='latin1').item()
            elif filename.endswith('.npz'):
                obj = dict(np.load(filename))
            # obj is now a dict and falls through to the dict branch below.
        elif len(tf.gfile.Glob(obj + "*")):
            # Assume to be a TF checkpoint.
            # A TF checkpoint must be a prefix of an actual file.
            return (SaverRestoreRelaxed if ignore_mismatch else SaverRestore)(obj)
        else:
            raise ValueError("Invalid argument to SmartInit: " + obj)

    if isinstance(obj, dict):
        return DictRestore(obj, ignore_mismatch=ignore_mismatch)
    # str() is required here: concatenating a str with a type object raises TypeError.
    raise ValueError("Invalid argument to SmartInit: " + str(type(obj)))


get_model_loader = SmartInit
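
The branch order above decides which loader is chosen, and it can be easier to follow in isolation. The following is a minimal sketch, not part of tensorpack: `describe_smartinit_dispatch` is a hypothetical helper that returns the name of the class `SmartInit` would pick instead of constructing it. It assumes local paths only, substituting `glob.glob` and `os.path.exists` for `tf.gfile.Glob` and `tf.gfile.Exists`, and it never loads any file.

```python
import glob
import os


def describe_smartinit_dispatch(obj, ignore_mismatch=False):
    """Hypothetical helper: name the class SmartInit would return for obj.

    Mirrors only the order of the checks in SmartInit. Nothing is loaded,
    and the stdlib stands in for tf.gfile, so only local paths are handled.
    """
    # Falsy inputs (None, "", [], {}) all take the no-op branch.
    if not obj:
        return "JustCurrentSession"
    # A list is resolved element by element and chained in order.
    if isinstance(obj, list):
        return ["ChainInit"] + [
            describe_smartinit_dispatch(x, ignore_mismatch) for x in obj
        ]
    if isinstance(obj, str):
        path = os.path.expanduser(obj)
        if path.endswith((".npy", ".npz")):
            # SmartInit asserts that the file exists, then loads it into a dict.
            assert os.path.exists(path), "File {} does not exist!".format(path)
            return "DictRestore"
        if glob.glob(path + "*"):
            # Any file sharing this prefix makes the string count as a checkpoint.
            return "SaverRestoreRelaxed" if ignore_mismatch else "SaverRestore"
        raise ValueError("Invalid argument to SmartInit: " + path)
    if isinstance(obj, dict):
        return "DictRestore"
    raise ValueError("Invalid argument to SmartInit: " + str(type(obj)))
```

Two consequences of this ordering are worth noting. An empty dict or empty list is treated the same as None, because the falsy check runs first. A string is accepted as a checkpoint whenever any file starts with that prefix, so a mistyped prefix that happens to match another file will not be rejected at this stage.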

Callers 15

evaluate_rcnn · Function · 0.90
checkpoint-prof.py · File · 0.90
imagenet-resnet.py · File · 0.90
do_visualize · Function · 0.90
predict.py · File · 0.90
shufflenet.py · File · 0.90
alexnet-dorefa.py · File · 0.90
boilerplate.py · File · 0.85
DiscoGAN-CelebA.py · File · 0.85
WGAN.py · File · 0.85
sample · Function · 0.85
Image2Image.py · File · 0.85

Calls 5

JustCurrentSession · Class · 0.85
ChainInit · Class · 0.85
DictRestore · Class · 0.85
format · Method · 0.80
load · Method · 0.45

Tested by

no test coverage detected