Build the TensorRT engine and serialize it to disk. :param engine_path: The path where to serialize the engine to. :param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'. :param calib_input: The path to a directory holding the calibration images.
(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=25000,
calib_batch_size=8, calib_preprocessor=None)
| 225 | self.builder.max_batch_size = self.batch_size |
| 226 | |
def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=25000,
                  calib_batch_size=8, calib_preprocessor=None):
    """
    Build the TensorRT engine and serialize it to disk.

    Only the "fp16" precision changes the builder configuration in this method. For any other
    value no precision flag is set and the engine is built with the builder config as it
    currently stands. The calib_* arguments are accepted for interface compatibility but are
    not consumed by this implementation.

    :param engine_path: The path where to serialize the engine to.
    :param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
    :param calib_input: The path to a directory holding the calibration images.
    :param calib_cache: The path where to write the calibration cache to, or if it already exists, load it from.
    :param calib_num_images: The maximum number of images to use for calibration.
    :param calib_batch_size: The batch size to use for the calibration process.
    :param calib_preprocessor: The ImageBatcher preprocessor algorithm to use.
    :raises RuntimeError: If the builder fails to produce an engine.
    """
    engine_path = os.path.realpath(engine_path)
    engine_dir = os.path.dirname(engine_path)
    os.makedirs(engine_dir, exist_ok=True)
    log.info("Building {} Engine in {}".format(precision, engine_path))

    if precision == "fp16":
        if not self.builder.platform_has_fast_fp16:
            log.warning("FP16 is not supported natively on this platform/device")
        else:
            self.config.set_flag(trt.BuilderFlag.FP16)
    elif precision != "fp32":
        # Make the silent fall-through visible: nothing below configures this precision,
        # so the caller would otherwise get an engine that does not match the request.
        log.warning(
            "Precision '{}' is not configured by this builder; "
            "no precision flag will be set for the engine".format(precision))

    # Build first and check the result explicitly. A failed build yields None, which
    # would otherwise surface as an unrelated context-manager error in the 'with' below.
    engine = self.builder.build_engine(self.network, self.config)
    if engine is None:
        log.error("Failed to build engine for {}".format(engine_path))
        raise RuntimeError("TensorRT failed to build the engine for {}".format(engine_path))

    # The output file is only opened once a valid engine exists, so a failed build
    # never truncates a previously serialized engine at the same path.
    with engine, open(engine_path, "wb") as f:
        log.info("Serializing engine to file: {:}".format(engine_path))
        f.write(engine.serialize())
no test coverage detected