Repository: github.com/tensorpack/tensorpack

Class KerasPhaseCallback

tensorpack/contrib/keras.py, lines 114–141

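Keras models that use learning_phase (for example through Dropout or BatchNormalization layers) need an extra input that tells the graph whether it is running in training or inference mode. This callback feeds that input. It is used in two places: by the trainer with isTrain=True, and by InferenceRunner with isTrain=False, in the form of hooks. If you use KerasModel or setup_keras_trainer, the callback is added automatically when needed.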
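To see why a phase-dependent model needs this extra input, here is a minimal pure-Python sketch with no TensorFlow or Keras involved. The names phase_feed, toy_dropout and the "learning_phase" key are hypothetical stand-ins: the key plays the role of the tensor returned by keras.backend.learning_phase(), and phase_feed mimics the int(isTrain) value the callback feeds on every run.

```python
import random

# Hypothetical stand-in for the tensor returned by keras.backend.learning_phase().
LEARNING_PHASE = "learning_phase"


def phase_feed(is_train):
    """Build the extra input: 1 while training, 0 during inference."""
    assert isinstance(is_train, bool), is_train
    return {LEARNING_PHASE: int(is_train)}


def toy_dropout(values, feed, rate=0.5, rng=None):
    """Phase-dependent dropout, a toy version of a learning_phase-aware layer.

    Phase 1 (training): drop each value with probability `rate` and scale
    the survivors by 1 / (1 - rate). Phase 0 (inference): return the input.
    """
    if LEARNING_PHASE not in feed:
        # Without the extra input, the layer cannot choose a branch.
        raise KeyError("learning phase was not fed")
    if feed[LEARNING_PHASE] == 0:
        return list(values)
    rng = rng if rng is not None else random.Random(0)
    keep = 1.0 - rate
    return [v / keep if rng.random() < keep else 0.0 for v in values]


x = [1.0, 2.0, 3.0, 4.0]
print(toy_dropout(x, phase_feed(False)))  # inference: [1.0, 2.0, 3.0, 4.0]
print(toy_dropout(x, phase_feed(True)))   # training: each entry is 0.0 or doubled
```

The same toy "graph" produces different results depending only on the fed phase value, which is the situation the trainer (isTrain=True) and the inference hooks (isTrain=False) handle in the real callback.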


# Imports added for readability; the real module imports these at the top of
# tensorpack/contrib/keras.py (the exact Keras import may differ by version).
import tensorflow as tf
from tensorflow import keras

from tensorpack.callbacks import Callback, CallbackToHook
from tensorpack.callbacks.inference_runner import InferenceRunnerBase
from tensorpack.utils import logger


class KerasPhaseCallback(Callback):
    """
    Keras needs an extra input if learning_phase is used by the model.
    This callback will be used:
    1. By the trainer with isTrain=True
    2. By InferenceRunner with isTrain=False, in the form of hooks

    If you use :class:`KerasModel` or :func:`setup_keras_trainer`,
    this callback will be automatically added when needed.
    """
    def __init__(self, isTrain):
        assert isinstance(isTrain, bool), isTrain
        self._isTrain = isTrain
        self._learning_phase = keras.backend.learning_phase()

    def _setup_graph(self):
        logger.info("Using Keras learning phase {} in the graph!".format(
            self._learning_phase.name))
        cbs = self.trainer._callbacks.cbs
        for cb in cbs:
            # XXX HACK: reaches into the trainer's private callback list so that
            # every inference runner also feeds the learning phase (as 0) via a hook.
            if isinstance(cb, InferenceRunnerBase):
                h = CallbackToHook(KerasPhaseCallback(False))
                cb.register_hook(h)

    def _before_run(self, ctx):
        # On every session run, feed 1 while training and 0 during inference.
        return tf.train.SessionRunArgs(
            fetches=[], feed_dict={self._learning_phase: int(self._isTrain)})
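The control flow of _setup_graph and _before_run can be sketched without TensorFlow as below. Everything here is a hypothetical toy: ToyTrainer, ToyInferenceRunner and PhaseFeeder stand in for the tensorpack trainer, InferenceRunnerBase and KerasPhaseCallback; feeds are plain dicts rather than tf.train.SessionRunArgs; and the inference-phase feeder is registered directly instead of being wrapped in CallbackToHook.

```python
class ToyCallback:
    """Hypothetical base callback: may contribute feeds before each run."""

    def setup_graph(self, trainer):
        self.trainer = trainer
        self._setup_graph()

    def _setup_graph(self):
        pass

    def before_run(self):
        return {}


class ToyInferenceRunner(ToyCallback):
    """Hypothetical inference runner that collects feeds from its own hooks."""

    def __init__(self):
        self.hooks = []

    def register_hook(self, hook):
        self.hooks.append(hook)

    def run_inference(self):
        # Merge whatever every registered hook asks to feed for this run.
        feed = {}
        for hook in self.hooks:
            feed.update(hook.before_run())
        return feed


class PhaseFeeder(ToyCallback):
    """Toy counterpart of KerasPhaseCallback."""

    def __init__(self, is_train):
        assert isinstance(is_train, bool), is_train
        self.is_train = is_train

    def _setup_graph(self):
        # Counterpart of the "XXX HACK": scan the trainer's callbacks and give
        # every inference runner a feeder that reports the inference phase.
        for cb in self.trainer.callbacks:
            if isinstance(cb, ToyInferenceRunner):
                cb.register_hook(PhaseFeeder(False))

    def before_run(self):
        # Counterpart of _before_run: feed 1 for training, 0 for inference.
        return {"learning_phase": int(self.is_train)}


class ToyTrainer:
    """Hypothetical trainer: sets up callbacks, then merges their feeds per step."""

    def __init__(self, callbacks):
        self.callbacks = callbacks
        for cb in callbacks:
            cb.setup_graph(self)

    def train_step_feed(self):
        feed = {}
        for cb in self.callbacks:
            feed.update(cb.before_run())
        return feed


runner = ToyInferenceRunner()
trainer = ToyTrainer([PhaseFeeder(True), runner])
print(trainer.train_step_feed())  # {'learning_phase': 1}
print(runner.run_inference())     # {'learning_phase': 0}
```

In this sketch only the training-phase feeder knows about the learning-phase input, so it is also the object that patches each inference runner; users do not have to attach a second, inference-phase callback by hand.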

Callers (3)

mnist-keras.py (file, 0.90)
_setup_graph (method, 0.85)
setup_keras_trainer (function, 0.85)

Calls

No outgoing calls detected.

Tested by

No test coverage detected.