A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
| 101 | |
| 102 | |
| 103 | class InferenceRunner(InferenceRunnerBase): |
| 104 | """ |
| 105 | A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`. |
| 106 | """ |
| 107 | |
def __init__(self, input, infs, tower_name='InferenceTower', tower_func=None, device=0):
    """
    Set up an inference runner over ``input`` with the given inferencers.

    Args:
        input (InputSource or DataFlow): the :class:`InputSource` to run
            inference on. A DataFlow is wrapped in a :class:`FeedInput`.
            A :class:`StagingInput` is rejected.
        infs (list): a list of :class:`Inferencer` instances.
        tower_name (str): the name scope of the tower to build.
            If multiple InferenceRunner are used, each needs a different tower_name.
        tower_func (tfutils.TowerFunc or None): the tower function used to build the graph.
            Defaults to calling `trainer.tower_func` under a `training=False` TowerContext,
            but a different tower function can be given
            if you need to run inference with several different graphs.
        device (int): the device to use.
    """
    # A plain DataFlow gets wrapped in a FeedInput. infinite=False lets a
    # dataflow without a size stop normally.
    # TODO a better way to handle inference size
    source = FeedInput(input, infinite=False) if isinstance(input, DataFlow) else input
    assert isinstance(source, InputSource), source
    assert not isinstance(source, StagingInput), source

    self._tower_name = tower_name
    # Both the raw device index and the value derived from it are kept.
    self._device_id = device
    self._device = _device_from_int(device)
    self._tower_func = tower_func
    super(InferenceRunner, self).__init__(source, infs)
| 133 | |
def _build_hook(self, inf):
    """
    Wrap one inferencer in an :class:`InferencerToHook`.

    The fetches requested by ``inf`` (via ``inf.get_fetches()``) are looked up
    on the inference tower handle, and the result is passed to the hook
    together with the inferencer.

    Args:
        inf: the inferencer to wrap; it must provide ``get_fetches()``.

    Returns:
        InferencerToHook: a hook built from ``inf`` and its resolved fetches.
    """
    requested = inf.get_fetches()
    resolved = self._tower_handle.get_tensors(requested)
    return InferencerToHook(inf, resolved)
| 138 | |
| 139 | def _setup_graph(self): |
| 140 | if self._tower_func is None: |
| 141 | assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!" |
| 142 | self._tower_func = self.trainer.tower_func |
| 143 | input_callbacks = self._input_source.setup(self._tower_func.input_signature) |
| 144 | |
| 145 | vs_name = self.trainer._vs_name_for_predictor(self._device_id) |
| 146 | logger.info("[InferenceRunner] Building tower '{}' on device {} {}...".format( |
| 147 | self._tower_name, self._device, |
| 148 | "with variable scope '{}'".format(vs_name) if vs_name else '')) |
| 149 | with tf.variable_scope(tf.get_variable_scope(), reuse=True), \ |
| 150 | tf.device(self._device), \ |
| 151 | PredictTowerContext(self._tower_name, vs_name=vs_name): |
| 152 | self._tower_func(*self._input_source.get_input_tensors()) |
| 153 | self._tower_handle = self._tower_func.towers[-1] |
| 154 | |
| 155 | for h in [self._build_hook(inf) for inf in self.infs]: |
| 156 | self.register_hook(h) |
| 157 | # trigger_{step,epoch}, {before,after}_epoch is ignored. |
| 158 | # We assume that InputSource callbacks won't use these methods |
| 159 | self._input_callbacks = Callbacks(input_callbacks) |
| 160 | for h in self._input_callbacks.get_hooks(): |
no outgoing calls
no test coverage detected