MCPcopy
hub / github.com/ultralytics/yolov5 / export_engine

Function export_engine

export.py:242–300  ·  view source on GitHub ↗
(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:'))

Source from the content-addressed store, hash-verified

240
@try_export
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    """Export a YOLOv5 model to a serialized TensorRT engine file.

    The model is first exported to ONNX (opset 12) through ``export_onnx``; the resulting
    ``.onnx`` file is then parsed by TensorRT and built into a ``.engine`` file that shares
    the stem of ``file``. See https://developer.nvidia.com/tensorrt

    Args:
        model: Model to export; forwarded to ``export_onnx``. On TensorRT 7,
            ``model.model[-1].anchor_grid`` is temporarily replaced during the ONNX export.
        im: Example input tensor. Must not live on the CPU. Its shape defines the
            optimization profile when ``dynamic`` is set.
        file: Path-like object providing ``with_suffix``; determines the ``.onnx`` and
            ``.engine`` file locations.
        half: Request an FP16 engine (honoured only if the builder reports fast FP16 support).
        dynamic: Add an optimization profile spanning batch sizes 1 to ``im.shape[0]``.
        simplify: Forwarded to ``export_onnx``.
        workspace: Builder workspace size in GiB.
        verbose: Lower the TensorRT logger's minimum severity to VERBOSE.
        prefix: Prefix prepended to log messages.

    Raises:
        AssertionError: If ``im`` is on the CPU or the ONNX file does not exist after export.
        RuntimeError: If the ONNX file cannot be parsed or the engine build yields nothing.
    """
    # NOTE(review): the listing this was taken from reports the function as spanning one more
    # line than is shown; a trailing statement (possibly a return consumed by ``try_export``)
    # may be missing here -- confirm against the upstream file before merging.
    assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
    try:
        import tensorrt as trt
    except Exception:
        # Best-effort install on Linux, then retry; on other platforms the retry re-raises.
        if platform.system() == 'Linux':
            check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
        import tensorrt as trt

    if trt.__version__.split('.')[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
        head = model.model[-1]
        grid = head.anchor_grid
        head.anchor_grid = [a[..., :1, :1, :] for a in grid]
        try:
            export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
        finally:
            # Always restore the caller's model, even if the ONNX export raises.
            head.anchor_grid = grid
    else:  # TensorRT >= 8
        check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
        export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
    onnx = file.with_suffix('.onnx')

    LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
    assert onnx.exists(), f'failed to export ONNX file: {onnx}'
    f = file.with_suffix('.engine')  # TensorRT engine file
    logger = trt.Logger(trt.Logger.INFO)
    if verbose:
        logger.min_severity = trt.Logger.Severity.VERBOSE

    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    workspace_bytes = int(workspace * (1 << 30))  # GiB -> bytes
    if hasattr(config, 'set_memory_pool_limit'):
        # Preferred API where available; avoids the TRT 8.4 deprecation notice.
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
    else:
        config.max_workspace_size = workspace_bytes

    flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnx)):
        raise RuntimeError(f'failed to load ONNX file: {onnx}')

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    for inp in inputs:
        LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
    for out in outputs:
        LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')

    if dynamic:
        if im.shape[0] <= 1:
            LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
        profile = builder.create_optimization_profile()
        trailing = tuple(im.shape[1:])
        for inp in inputs:
            # (min, opt, max) shapes: batch 1, half the example batch, the full example batch.
            profile.set_shape(inp.name, (1, *trailing), (max(1, im.shape[0] // 2), *trailing), im.shape)
        config.add_optimization_profile(profile)

    use_fp16 = builder.platform_has_fast_fp16 and half
    LOGGER.info(f'{prefix} building FP{16 if use_fp16 else 32} engine as {f}')
    if use_fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    if hasattr(builder, 'build_engine'):
        engine = builder.build_engine(network, config)
        if engine is None:
            raise RuntimeError(f'failed to build TensorRT engine from {onnx}')
        # The output file is opened only after a successful build, so a failed build
        # never leaves an empty .engine file behind.
        with engine, open(f, 'wb') as t:
            t.write(engine.serialize())
    else:
        # Newer TensorRT releases drop build_engine; build the serialized plan directly.
        serialized = builder.build_serialized_network(network, config)
        if serialized is None:
            raise RuntimeError(f'failed to build TensorRT engine from {onnx}')
        with open(f, 'wb') as t:
            t.write(serialized)

Callers 1

runFunction · 0.85

Calls 5

colorstrFunction · 0.90
check_requirementsFunction · 0.90
check_versionFunction · 0.90
export_onnxFunction · 0.85
infoMethod · 0.80

Tested by

no test coverage detected