| 202 | |
| 203 | |
class TfPreprocessTransform:
    """Callable that runs a TensorFlow image-preprocessing graph and returns a CHW uint8 array.

    The TF graph is built once at construction time. A ``tf.Session`` is opened
    lazily on the first call and then reused for the lifetime of the object.

    Args:
        is_training: Forwarded to ``preprocess_image`` (presumably selects the
            training vs. evaluation pipeline — confirm against its definition).
        size: Output size. Either an int, or a tuple/list whose FIRST element is
            used; any remaining elements are ignored.
        interpolation: Interpolation name forwarded to ``preprocess_image``.
    """

    def __init__(self, is_training=False, size=224, interpolation='bicubic'):
        self.is_training = is_training
        # Accept (H, W)-style sequences as well as a plain int. The pipeline takes
        # a single size value, so only the first element is kept.
        self.size = size[0] if isinstance(size, (tuple, list)) else size
        self.interpolation = interpolation
        self._image_bytes = None  # input placeholder; assigned by _build_tf_graph()
        self.process_image = self._build_tf_graph()  # graph output tensor
        self.sess = None  # opened on the first __call__

    def _build_tf_graph(self):
        """Build the preprocessing ops on the CPU and return the output tensor.

        Side effect: stores the scalar string input placeholder on
        ``self._image_bytes`` so that ``__call__`` can feed it.
        """
        with tf.device('/cpu:0'):
            self._image_bytes = tf.placeholder(
                shape=[],
                dtype=tf.string,
            )
            # NOTE(review): the positional ``False`` is presumably a precision flag
            # (e.g. use_bfloat16) of preprocess_image — confirm against its signature.
            img = preprocess_image(
                self._image_bytes, self.is_training, False, self.size, self.interpolation)
        return img

    def __call__(self, image_bytes):
        """Preprocess one image.

        Args:
            image_bytes: Value fed to the graph's scalar string placeholder
                (presumably encoded image bytes — confirm against preprocess_image).

        Returns:
            A ``numpy.uint8`` array in CHW layout. A 2-D (single-channel) graph
            output gets a singleton channel axis, giving shape ``(1, H, W)``.
        """
        if self.sess is None:
            # Bind the session to the graph the ops were built in, rather than to
            # whatever graph happens to be the default at call time.
            self.sess = tf.Session(graph=self.process_image.graph)
        img = self.sess.run(self.process_image, feed_dict={self._image_bytes: image_bytes})
        # Round to nearest (NumPy rounds ties to even), then saturate into uint8 range.
        img = img.round().clip(0, 255).astype(np.uint8)
        if img.ndim < 3:
            img = np.expand_dims(img, axis=-1)
        return np.transpose(img, (2, 0, 1))  # HWC to CHW