Build a TensorRT engine from an ONNX model.

Args:
    model_params (dict): Optional model-specific parameters, e.g.:
        - qwen2_vl_dim (int): Dimension for the Qwen2-VL model
        - min_hw_dims (int): Minimum HW dimensions
        - max_hw_dims (int): Maximum HW dimensions
        - num_frames (int): Number of frames for video models
(model_type,
input_sizes,
onnx_dir,
engine_dir,
max_batch_size,
dtype=torch.float16,
model_params=None,
onnx_name='model.onnx',
engine_name='model.engine',
delete_onnx=True,
logger=trt.Logger(trt.Logger.INFO))
| 176 | |
| 177 | |
| 178 | def build_trt_engine(model_type, |
| 179 | input_sizes, |
| 180 | onnx_dir, |
| 181 | engine_dir, |
| 182 | max_batch_size, |
| 183 | dtype=torch.float16, |
| 184 | model_params=None, |
| 185 | onnx_name='model.onnx', |
| 186 | engine_name='model.engine', |
| 187 | delete_onnx=True, |
| 188 | logger=trt.Logger(trt.Logger.INFO)): |
| 189 | """Build TensorRT engine from ONNX model. |
| 190 | |
| 191 | Args: |
| 192 | model_params (dict): Optional model specific parameters, e.g.: |
| 193 | - qwen2_vl_dim (int): Dimension for Qwen2-VL model |
| 194 | - min_hw_dims (int): Minimum HW dimensions |
| 195 | - max_hw_dims (int): Maximum HW dimensions |
| 196 | - num_frames (int): Number of frames for video models |
| 197 | """ |
| 198 | model_params = model_params or {} |
| 199 | onnx_file = f'{onnx_dir}/{onnx_name}' |
| 200 | engine_file = f'{engine_dir}/{engine_name}' |
| 201 | config_file = f'{engine_dir}/config.json' |
| 202 | logger.log(trt.Logger.INFO, f"Building TRT engine to {engine_file}") |
| 203 | |
| 204 | builder = trt.Builder(logger) |
| 205 | network = builder.create_network( |
| 206 | 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) |
| 207 | profile = builder.create_optimization_profile() |
| 208 | |
| 209 | config_args = { |
| 210 | "precision": torch_dtype_to_str(dtype), |
| 211 | "model_type": model_type, |
| 212 | "strongly_typed": False, |
| 213 | "max_batch_size": max_batch_size, |
| 214 | "model_name": "multiModal" |
| 215 | } |
| 216 | |
| 217 | if "num_frames" in model_params: |
| 218 | config_args["num_frames"] = model_params["num_frames"] |
| 219 | |
| 220 | config_wrapper = Builder().create_builder_config(**config_args) |
| 221 | config = config_wrapper.trt_builder_config |
| 222 | |
| 223 | parser = trt.OnnxParser(network, logger) |
| 224 | |
| 225 | with open(onnx_file, 'rb') as model: |
| 226 | if not parser.parse(model.read(), os.path.abspath(onnx_file)): |
| 227 | logger.log(trt.Logger.ERROR, "Failed parsing %s" % onnx_file) |
| 228 | for error in range(parser.num_errors): |
| 229 | logger.log(trt.Logger.ERROR, parser.get_error(error)) |
| 230 | logger.log(trt.Logger.INFO, "Succeeded parsing %s" % onnx_file) |
| 231 | |
| 232 | nBS = -1 |
| 233 | nMinBS = 1 |
| 234 | nOptBS = max(nMinBS, int(max_batch_size / 2)) |
| 235 | nMaxBS = max_batch_size |
no test coverage detected