MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / InferenceRunner

Class InferenceRunner

tensorpack/callbacks/inference_runner.py:103–180  ·  view source on GitHub ↗

A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.

Source from the content-addressed store, hash-verified

101
102
103class InferenceRunner(InferenceRunnerBase):
104 """
105 A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
106 """
107
108 def __init__(self, input, infs, tower_name='InferenceTower', tower_func=None, device=0):
109 """Validate (and, for a DataFlow, wrap) the input source and record the settings used to build the inference tower.
110 Args:
111 input (InputSource or DataFlow): The :class:`InputSource` to run
112 inference on. A DataFlow is wrapped in :class:`FeedInput` (with ``infinite=False``); a :class:`StagingInput` is rejected by an assertion.
113 infs (list): a list of :class:`Inferencer` instances.
114 tower_name (str): the name scope of the tower to build.
115 If multiple InferenceRunners are used, each needs a different tower_name.
116 tower_func (tfutils.TowerFunc or None): the tower function to be used to build the graph.
117 By default, calls `trainer.tower_func` under a `training=False` TowerContext,
118 but you can change it to a different tower function
119 if you need to run inference with several different graphs.
120 device (int): the device to build the inference tower on; converted to a device spec with ``_device_from_int``.
121 """
122 if isinstance(input, DataFlow):
123 # use infinite=False so that a dataflow without a size will stop normally
124 # TODO: find a better way to handle the inference size
125 input = FeedInput(input, infinite=False)
126 assert isinstance(input, InputSource), input
127 assert not isinstance(input, StagingInput), input
128 self._tower_name = tower_name
129 self._device_id = device  # raw int; later passed to trainer._vs_name_for_predictor()
130 self._device = _device_from_int(device)  # later used with tf.device() when building the tower
131 self._tower_func = tower_func  # None means "use trainer.tower_func", resolved in _setup_graph
132 super(InferenceRunner, self).__init__(input, infs)
133
134 def _build_hook(self, inf):  # wrap one Inferencer into a hook that fetches its requested tensors
135 out_names = inf.get_fetches()  # tensor names this inferencer asks for
136 fetches = self._tower_handle.get_tensors(out_names)  # resolved in the inference tower that _setup_graph builds before calling this
137 return InferencerToHook(inf, fetches)
138
139 def _setup_graph(self):
140 if self._tower_func is None:
141 assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
142 self._tower_func = self.trainer.tower_func
143 input_callbacks = self._input_source.setup(self._tower_func.input_signature)
144
145 vs_name = self.trainer._vs_name_for_predictor(self._device_id)
146 logger.info("[InferenceRunner] Building tower '{}' on device {} {}...".format(
147 self._tower_name, self._device,
148 "with variable scope '{}'".format(vs_name) if vs_name else ''))
149 with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
150 tf.device(self._device), \
151 PredictTowerContext(self._tower_name, vs_name=vs_name):
152 self._tower_func(*self._input_source.get_input_tensors())
153 self._tower_handle = self._tower_func.towers[-1]
154
155 for h in [self._build_hook(inf) for inf in self.infs]:
156 self.register_hook(h)
157 # trigger_{step,epoch}, {before,after}_epoch is ignored.
158 # We assume that InputSource callbacks won't use these methods
159 self._input_callbacks = Callbacks(input_callbacks)
160 for h in self._input_callbacks.get_hooks():

Callers 15

fit · Method · 0.85
get_config · Function · 0.85
get_config · Function · 0.85
get_config · Function · 0.85
cifar10-resnet.py · File · 0.85
get_config · Function · 0.85
get_config · Function · 0.85
get_config · Function · 0.85
mnist-tflayers.py · File · 0.85
get_config · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected