hub / github.com/NVIDIA/TensorRT-LLM / build_trt_engine

Function build_trt_engine

tensorrt_llm/tools/multimodal_builder.py:178–355

Build TensorRT engine from ONNX model.

Args:
    model_params (dict): Optional model specific parameters, e.g.:
        - qwen2_vl_dim (int): Dimension for Qwen2-VL model
        - min_hw_dims (int): Minimum HW dimensions
        - max_hw_dims (int): Maximum HW dimensions
        - num_frames (int): Number of frames for video models

build_trt_engine(model_type,
                 input_sizes,
                 onnx_dir,
                 engine_dir,
                 max_batch_size,
                 dtype=torch.float16,
                 model_params=None,
                 onnx_name='model.onnx',
                 engine_name='model.engine',
                 delete_onnx=True,
                 logger=trt.Logger(trt.Logger.INFO))
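
As a rough illustration of how `model_params` is consumed, the sketch below mirrors the first steps of the function body in plain Python: a missing `model_params` is normalized to an empty dict, and `num_frames` is forwarded to the builder configuration only when the caller supplied it. The helper name `make_config_args`, the precision string, and the model type strings are assumptions made for this example. The real function derives the precision from `dtype` via `torch_dtype_to_str`.

```python
def make_config_args(model_type, max_batch_size, precision, model_params=None):
    """Hypothetical helper mirroring how build_trt_engine assembles config_args."""
    # None is treated the same as "no extra parameters".
    model_params = model_params or {}

    config_args = {
        "precision": precision,  # assumption: already a string such as "float16"
        "model_type": model_type,
        "strongly_typed": False,
        "max_batch_size": max_batch_size,
        "model_name": "multiModal",
    }

    # num_frames is forwarded only when the caller provided it.
    if "num_frames" in model_params:
        config_args["num_frames"] = model_params["num_frames"]

    return config_args


# Illustrative model type strings, not taken from the listing.
image_args = make_config_args("llava", 4, "float16")
video_args = make_config_args("video-neva", 4, "float16", {"num_frames": 8})
```

In this sketch `image_args` has no `num_frames` key, while `video_args` carries the value passed in `model_params`.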

Source from the content-addressed store, hash-verified

def build_trt_engine(model_type,
                     input_sizes,
                     onnx_dir,
                     engine_dir,
                     max_batch_size,
                     dtype=torch.float16,
                     model_params=None,
                     onnx_name='model.onnx',
                     engine_name='model.engine',
                     delete_onnx=True,
                     logger=trt.Logger(trt.Logger.INFO)):
    """Build TensorRT engine from ONNX model.

    Args:
        model_params (dict): Optional model specific parameters, e.g.:
            - qwen2_vl_dim (int): Dimension for Qwen2-VL model
            - min_hw_dims (int): Minimum HW dimensions
            - max_hw_dims (int): Maximum HW dimensions
            - num_frames (int): Number of frames for video models
    """
    model_params = model_params or {}
    onnx_file = f'{onnx_dir}/{onnx_name}'
    engine_file = f'{engine_dir}/{engine_name}'
    config_file = f'{engine_dir}/config.json'
    logger.log(trt.Logger.INFO, f"Building TRT engine to {engine_file}")

    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    profile = builder.create_optimization_profile()

    config_args = {
        "precision": torch_dtype_to_str(dtype),
        "model_type": model_type,
        "strongly_typed": False,
        "max_batch_size": max_batch_size,
        "model_name": "multiModal"
    }

    if "num_frames" in model_params:
        config_args["num_frames"] = model_params["num_frames"]

    config_wrapper = Builder().create_builder_config(**config_args)
    config = config_wrapper.trt_builder_config

    parser = trt.OnnxParser(network, logger)

    with open(onnx_file, 'rb') as model:
        if not parser.parse(model.read(), os.path.abspath(onnx_file)):
            logger.log(trt.Logger.ERROR, "Failed parsing %s" % onnx_file)
            for error in range(parser.num_errors):
                logger.log(trt.Logger.ERROR, parser.get_error(error))
        logger.log(trt.Logger.INFO, "Succeeded parsing %s" % onnx_file)

    nBS = -1
    nMinBS = 1
    nOptBS = max(nMinBS, int(max_batch_size / 2))
    nMaxBS = max_batch_size
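
The last lines of the excerpt derive a minimum, optimal, and maximum batch size from `max_batch_size`. A minimal sketch of that arithmetic follows, using a hypothetical helper `batch_size_profile` that does not exist in the source. The excerpt ends before these values are used, so how they are applied to the optimization profile is not shown here.

```python
def batch_size_profile(max_batch_size):
    """Hypothetical helper: (min, opt, max) batch sizes as computed in the listing."""
    n_min = 1
    # Half of the maximum, truncated toward zero, but never below the minimum.
    n_opt = max(n_min, int(max_batch_size / 2))
    n_max = max_batch_size
    return n_min, n_opt, n_max


for max_bs in (1, 5, 8):
    print(max_bs, batch_size_profile(max_bs))
```

For `max_batch_size=1` the halved value truncates to 0, so the `max` guard keeps the optimal batch size at 1. For 5 and 8 the optimal sizes come out as 2 and 4.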

Callers 15

build_blip2_engine · Function · 0.85
build_pix2struct_engine · Function · 0.85
build_llava_engine · Function · 0.85
build_vila_engine · Function · 0.85
build_nougat_engine · Function · 0.85
build_cogvlm_engine · Function · 0.85
build_fuyu_engine · Function · 0.85
build_neva_engine · Function · 0.85
build_video_neva_engine · Function · 0.85
build_kosmos_engine · Function · 0.85
build_phi_engine · Function · 0.85

Calls 13

create_network · Method · 0.95
torch_dtype_to_str · Function · 0.90
Builder · Class · 0.90
max · Function · 0.85
create_builder_config · Method · 0.80
get_error · Method · 0.80
save_config · Method · 0.80
log · Method · 0.45
parse · Method · 0.45
get_input · Method · 0.45
get · Method · 0.45

Tested by

No test coverage detected.