MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / NewSessionCreator

Class NewSessionCreator

tensorpack/tfutils/sesscreate.py:35–91  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

33
34
class NewSessionCreator(tf.train.SessionCreator):
    """
    A :class:`tf.train.SessionCreator` that creates a new ``tf.Session`` and
    runs the global-variable, local-variable and table initializers in it.
    """

    def __init__(self, target='', config=None):
        """
        Args:
            target: same as the ``target`` argument of :meth:`Session.__init__()`.
            config: a :class:`tf.ConfigProto` instance, same as the ``config``
                argument of :meth:`Session.__init__()`. Defaults to
                :func:`tfutils.get_default_sess_config()`.
        """
        self.target = target

        if config is None:
            # distributed trainer doesn't support user-provided config
            # we set this attribute so that they can check
            self.user_provided_config = False
            config = get_default_sess_config()
        else:
            self.user_provided_config = True
            logger.warn(_WRN1)
            logger.warn(_WRN2)
        self.config = config

    def create_session(self):
        """
        Create a new session and run all variable and table initializers in it.

        Before each initializer is run, its backward dependency graph
        (including control inputs) is scanned. A warning is logged for every
        op that may block, because such an initializer is likely to hang.

        Returns:
            tf.Session: the newly created and initialized session.

        Raises:
            Whatever building or running the initializers raises. The session
            is closed before the exception propagates, so its resources are
            not leaked.
        """
        # Resolve the graph-walking helper once, before any session exists,
        # so a failed import cannot leak a session.
        try:
            from tensorflow.contrib.graph_editor import get_backward_walk_ops  # deprecated
        except ImportError:
            from tensorflow.python.ops.op_selector import get_backward_walk_ops

        def blocking_op(x):
            """
            Whether an op is possibly blocking.
            """
            # Ops whose registered definition is stateless cannot block.
            if x.op_def is not None and not x.op_def.is_stateful:
                return False
            # Queue, staging-area and ZMQ-reader ops wait for data.
            return ("Dequeue" in x.type or "Enqueue" in x.type
                    or "Unstage" in x.type or x.type in ("ZMQPull",))

        def run(op):
            # Warn about blocking dependencies before running the initializer.
            for dep_op in get_backward_walk_ops(op, control_inputs=True):
                if blocking_op(dep_op):
                    logger.warn(
                        "Initializer '{}' depends on a blocking op '{}'. "
                        "This initializer is likely to hang!".format(
                            op.name, dep_op.name))
            sess.run(op)

        sess = tf.Session(target=self.target, config=self.config)
        try:
            run(tf.global_variables_initializer())
            run(tf.local_variables_initializer())
            run(tf.tables_initializer())
        except BaseException:
            # Includes KeyboardInterrupt, e.g. a user aborting a hung
            # initializer. Release the session's resources, then re-raise.
            sess.close()
            raise
        return sess
92

Callers 3

__init__Method · 0.85
__init__Method · 0.85
train_with_defaultsMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected