Keras needs an extra input if learning_phase is used by the model. This callback will be used: 1. by the trainer with isTrain=True; 2. by InferenceRunner with isTrain=False, in the form of hooks. If you use :class:`KerasModel` or :func:`setup_keras_trainer`, this callback will be automatically added when needed.
| 112 | |
| 113 | |
class KerasPhaseCallback(Callback):
    """
    Feed Keras' ``learning_phase`` tensor, which Keras needs as an extra
    input whenever the model uses it.

    This callback is used in two ways:
        1. By the trainer, with ``isTrain=True``.
        2. By InferenceRunner, with ``isTrain=False``, in the form of hooks.

    If you use :class:`KerasModel` or :func:`setup_keras_trainer`,
    this callback will be automatically added when needed.
    """
    def __init__(self, isTrain):
        """
        Args:
            isTrain (bool): converted with ``int()`` (so 1 or 0) and fed to
                the learning-phase tensor in ``_before_run``.
        """
        assert isinstance(isTrain, bool), isTrain
        self._isTrain = isTrain
        # The learning-phase tensor is looked up once, at construction time.
        self._learning_phase = keras.backend.learning_phase()

    def _setup_graph(self):
        """Log the phase tensor and attach an isTrain=False feeder to every
        inference runner known to the trainer."""
        phase_name = self._learning_phase.name
        logger.info("Using Keras learning phase {} in the graph!".format(
            phase_name))
        # XXX HACK: this reaches into the trainer's private callback list.
        # Every InferenceRunnerBase gets a hook that feeds int(False) == 0.
        # Presumably the runner applies it to its own session runs
        # (NOTE(review): confirm against InferenceRunnerBase.register_hook).
        for candidate in self.trainer._callbacks.cbs:
            if not isinstance(candidate, InferenceRunnerBase):
                continue
            inference_hook = CallbackToHook(KerasPhaseCallback(False))
            candidate.register_hook(inference_hook)

    def _before_run(self, ctx):
        """Request no extra fetches; only feed the phase value (1 or 0)."""
        phase_feed = {self._learning_phase: int(self._isTrain)}
        return tf.train.SessionRunArgs(fetches=[], feed_dict=phase_feed)
| 142 | |
| 143 | |
| 144 | def setup_keras_trainer( |
no outgoing calls
no test coverage detected