(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:'))
| 240 | |
@try_export
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    """Export a YOLOv5 model to a serialized TensorRT engine file via an intermediate ONNX export.

    See https://developer.nvidia.com/tensorrt

    Args:
        model: Model to export; forwarded to ``export_onnx``. On TensorRT 7 the ``anchor_grid`` of its last
            layer is temporarily replaced while the ONNX file is produced, then restored.
        im: Example input; ``im.device.type`` must not be ``'cpu'``. With ``dynamic`` its shape is used as the
            maximum shape of the optimization profile.
        file: Path-like object supporting ``with_suffix``; the ``.onnx`` and ``.engine`` paths derive from it.
        half: Request an FP16 engine; only honoured when the builder reports fast FP16 support.
        dynamic: Register an optimization profile with a batch range of 1 .. ``im.shape[0]``.
        simplify: Forwarded to ``export_onnx``.
        workspace: Builder workspace limit in GiB (the value is shifted left by 30 bits to obtain bytes).
        verbose: Lower the TensorRT logger's minimum severity to VERBOSE.
        prefix: Text prepended to log messages.

    Raises:
        AssertionError: If ``im`` is on the CPU or the ONNX file does not exist after export.
        RuntimeError: If the ONNX file cannot be parsed, or the serialized-network build returns nothing.
    """
    assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
    try:
        import tensorrt as trt
    except Exception:
        # Requirements are only checked on Linux; on other platforms the retry below re-raises the failure.
        if platform.system() == 'Linux':
            check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
        import tensorrt as trt

    # Compare the whole major component rather than only the first character of the version string.
    trt_major = trt.__version__.split('.')[0]
    if trt_major == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
        detect = model.model[-1]
        grid = detect.anchor_grid
        detect.anchor_grid = [a[..., :1, :1, :] for a in grid]
        try:
            export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
        finally:
            # Always put the original grid back so a failed export does not leave the model modified.
            detect.anchor_grid = grid
    else:  # TensorRT >= 8
        check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
        export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
    onnx = file.with_suffix('.onnx')

    LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
    assert onnx.exists(), f'failed to export ONNX file: {onnx}'
    f = file.with_suffix('.engine')  # TensorRT engine file
    logger = trt.Logger(trt.Logger.INFO)
    if verbose:
        logger.min_severity = trt.Logger.Severity.VERBOSE

    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    workspace_bytes = workspace << 30  # GiB -> bytes
    if hasattr(config, 'set_memory_pool_limit'):
        # Newer API (addresses the TRT 8.4 deprecation notice for max_workspace_size).
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
    else:
        config.max_workspace_size = workspace_bytes

    flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnx)):
        raise RuntimeError(f'failed to load ONNX file: {onnx}')

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    for inp in inputs:
        LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
    for out in outputs:
        LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')

    if dynamic:
        if im.shape[0] <= 1:
            LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
        profile = builder.create_optimization_profile()
        tail = tuple(im.shape[1:])
        min_shape = (1, *tail)
        opt_shape = (max(1, im.shape[0] // 2), *tail)  # optimize for half of the maximum batch size
        for inp in inputs:
            profile.set_shape(inp.name, min_shape, opt_shape, im.shape)
        config.add_optimization_profile(profile)

    use_fp16 = builder.platform_has_fast_fp16 and half
    LOGGER.info(f'{prefix} building FP{16 if use_fp16 else 32} engine as {f}')
    if use_fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    if hasattr(builder, 'build_engine'):
        # Legacy path: build an engine object, then serialize it.
        with builder.build_engine(network, config) as engine:
            serialized = engine.serialize()
    else:
        # Builders without build_engine produce the serialized plan directly.
        serialized = builder.build_serialized_network(network, config)
        if serialized is None:
            raise RuntimeError(f'failed to build TensorRT engine from ONNX file: {onnx}')
    with open(f, 'wb') as t:
        t.write(serialized)
    # NOTE(review): no return statement is visible in this section of the file; confirm what the
    # `try_export` wrapper expects this function to return.
no test coverage detected