MCPcopy
hub / github.com/OpenPPL/ppq / build_engine

Function build_engine

ppq/utils/TensorRTUtil.py:212–263  ·  view source on GitHub ↗
(
    onnx_file: str, engine_file: str,
    fp16: bool = True, int8: bool = False, 
    int8_scale_file: str = None,
    explicit_batch: bool = True, 
    workspace: int = 4294967296, # 4GB
    )

Source from the content-addressed store, hash-verified

210
211
def build_engine(
    onnx_file: str, engine_file: str,
    fp16: bool = True, int8: bool = False,
    int8_scale_file: str = None,
    explicit_batch: bool = True,
    workspace: int = 4294967296, # 4GB
    ):
    """
    Build a TensorRT engine from an ONNX model and serialize it to ``engine_file``.

    Flags int8 and fp16 specify the precision of layers:
        For building FP32 engine: set int8 = False, fp16 = False, int8_scale_file = None
        For building FP16 engine: set int8 = False, fp16 = True,  int8_scale_file = None
        For building INT8 engine: set int8 = True,  fp16 = True,  int8_scale_file = 'json file name'

    Args:
        onnx_file: path of the ONNX model to parse.
        engine_file: path the serialized engine is written to (overwritten if present).
        fp16: set the TensorRT FP16 builder flag.
        int8: set the TensorRT INT8 builder flag; requires ``int8_scale_file``.
        int8_scale_file: JSON file handed to ``setDynamicRange`` (defined elsewhere in
            this module) to apply variable scales to the network.
        explicit_batch: create the network with the EXPLICIT_BATCH creation flag.
        workspace: builder workspace memory limit in bytes (default 4 GiB).

    Returns:
        None. If the ONNX file fails to parse, the parser errors are printed and
        None is returned without writing ``engine_file``.

    Raises:
        ValueError: ``int8`` is requested but ``int8_scale_file`` is None.
        FileNotFoundError: ``onnx_file`` does not exist.
        RuntimeError: TensorRT failed to build the engine.
    """
    # Validate inputs before allocating any TensorRT objects.
    if int8 and int8_scale_file is None:
        raise ValueError('Build Quantized TensorRT Engine Requires a JSON file which specifies variable scales, '
                         'however int8_scale_file is None now.')
    if not os.path.exists(onnx_file):
        raise FileNotFoundError(f'ONNX file {onnx_file} not found')

    TRT_LOGGER = trt.Logger()
    builder = trt.Builder(TRT_LOGGER)
    if explicit_batch:
        network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    else:
        network = builder.create_network()

    config = builder.create_builder_config()
    # TensorRT >= 8.4 replaces max_workspace_size with memory-pool limits;
    # the old attribute no longer exists in TensorRT 10.
    if hasattr(config, 'set_memory_pool_limit'):
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
    else:
        config.max_workspace_size = workspace

    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(onnx_file, 'rb') as model:
        if not parser.parse(model.read()):
            print('ERROR: Failed to parse the ONNX file.')
            for idx in range(parser.num_errors):
                print(parser.get_error(idx))
            return None

    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    if int8:
        # int8_scale_file is guaranteed non-None by the validation above.
        config.set_flag(trt.BuilderFlag.INT8)
        setDynamicRange(network, int8_scale_file)

    # build_serialized_network exists since TensorRT 8; build_engine was removed in 10.
    # Both signal failure by returning None.
    if hasattr(builder, 'build_serialized_network'):
        serialized = builder.build_serialized_network(network, config)
    else:
        engine = builder.build_engine(network, config)
        serialized = None if engine is None else engine.serialize()
    if serialized is None:
        raise RuntimeError(f'TensorRT failed to build an engine from {onnx_file}; '
                           'see the TensorRT logger output for details.')

    with open(engine_file, 'wb') as f:
        f.write(serialized)
264
265
266class MyProfiler(trt.IProfiler):

Callers 3

yolo_5.py · File · 0.90
02_Quantization.py · File · 0.90
Example_PTQ.py · File · 0.90

Calls 3

setDynamicRange · Function · 0.70
create_network · Method · 0.45
parse · Method · 0.45

Tested by

no test coverage detected