(
onnx_file: str, engine_file: str,
fp16: bool = True, int8: bool = False,
int8_scale_file: str = None,
explicit_batch: bool = True,
workspace: int = 4294967296, # 4GB
)
| 210 | |
| 211 | |
def build_engine(
    onnx_file: str, engine_file: str,
    fp16: bool = True, int8: bool = False,
    int8_scale_file: str = None,
    explicit_batch: bool = True,
    workspace: int = 4294967296,  # 4GB
):
    """
    Build a TensorRT engine from an ONNX model and serialize it to ``engine_file``.

    The ``fp16`` / ``int8`` flags select the layer precision:
      * FP32 engine: int8 = False, fp16 = False, int8_scale_file = None
      * FP16 engine: int8 = False, fp16 = True,  int8_scale_file = None
      * INT8 engine: int8 = True,  fp16 = True,  int8_scale_file = 'json file name'

    Args:
        onnx_file: Path of the ONNX model to parse.
        engine_file: Path the serialized engine is written to.
        fp16: If truthy, set ``trt.BuilderFlag.FP16`` on the builder config.
        int8: If truthy, set ``trt.BuilderFlag.INT8`` and apply the dynamic
            ranges from ``int8_scale_file``; requires ``int8_scale_file``.
        int8_scale_file: JSON file handed to ``setDynamicRange`` together with
            the parsed network (only used when ``int8`` is truthy).
        explicit_batch: If truthy, create the network with the
            ``EXPLICIT_BATCH`` creation flag.
        workspace: Value assigned to ``config.max_workspace_size``
            (default 4294967296 bytes = 4 GiB).

    Returns:
        The built engine on success. ``None`` if the ONNX file could not be
        parsed or the builder did not produce an engine.

    Raises:
        ValueError: ``int8`` is requested but ``int8_scale_file`` is None.
        FileNotFoundError: ``onnx_file`` does not exist.
    """
    # Validate arguments before creating any TensorRT objects, so that bad
    # calls fail fast and cheaply. Use truthiness here, matching the INT8
    # flag check below, so that e.g. int8=1 is validated as well.
    if int8 and int8_scale_file is None:
        raise ValueError('Build Quantized TensorRT Engine Requires a JSON file which specifies variable scales, '
                         'however int8_scale_file is None now.')
    if not os.path.exists(onnx_file):
        raise FileNotFoundError(f'ONNX file {onnx_file} not found')

    trt_logger = trt.Logger()
    builder = trt.Builder(trt_logger)
    if explicit_batch:
        # create_network takes a bitmask of NetworkDefinitionCreationFlag values.
        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        network = builder.create_network(network_flags)
    else:
        network = builder.create_network()

    config = builder.create_builder_config()
    config.max_workspace_size = workspace

    parser = trt.OnnxParser(network, trt_logger)
    with open(onnx_file, 'rb') as model:
        if not parser.parse(model.read()):
            print('ERROR: Failed to parse the ONNX file.')
            for error_index in range(parser.num_errors):
                print(parser.get_error(error_index))
            return None

    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    if int8:
        # int8_scale_file is guaranteed non-None here by the validation above.
        config.set_flag(trt.BuilderFlag.INT8)
        setDynamicRange(network, int8_scale_file)

    engine = builder.build_engine(network, config)
    if engine is None:
        # The builder may return None when it fails. Report the failure
        # instead of crashing on `None.serialize()`.
        print('ERROR: Failed to build the TensorRT engine.')
        return None

    with open(engine_file, "wb") as f:
        f.write(engine.serialize())
    return engine
| 265 | |
| 266 | class MyProfiler(trt.IProfiler): |
no test coverage detected