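@brief: Build one TensorRT engine from the network. @param network: Network object. @param builder_config: BuilderConfig object. @param managed_weights: Optional dict (default None); must be provided when any of the network's parameters is managed. @return: A serialized TRT engine.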
(self,
network: Network,
builder_config: BuilderConfig,
managed_weights: dict = None)
| 352 | |
| 353 | @_is_building |
| 354 | def build_engine(self, |
| 355 | network: Network, |
| 356 | builder_config: BuilderConfig, |
| 357 | managed_weights: dict = None) -> trt.IHostMemory: |
| 358 | ''' |
| 359 | @brief: Build one TensorRT engine from the network. |
| 360 | @param network: Network object. |
| 361 | @param builder_config: BuilderConfig object. |
| 362 | @return: A serialized TRT engine. |
| 363 | ''' |
| 364 | assert isinstance(network, Network) |
| 365 | builder_config.plugin_config = network.plugin_config |
| 366 | if builder_config.trt_builder_config.num_optimization_profiles == 0: |
| 367 | self._add_optimization_profile(network, builder_config) |
| 368 | logger.info( |
| 369 | f"Total optimization profiles added: {builder_config.trt_builder_config.num_optimization_profiles}" |
| 370 | ) |
| 371 | engine = None |
| 372 | |
| 373 | tik = time.time() |
| 374 | # Rename weights |
| 375 | if network.named_parameters is not None: |
| 376 | managed_parameters = [] |
| 377 | for name, param in network.named_parameters: |
| 378 | if param.is_managed(network): |
| 379 | assert managed_weights is not None, "managed_weights should be provided when enabled" |
| 380 | managed_parameters.append(param) |
| 381 | param.set_name(name, network) |
| 382 | continue |
| 383 | if param._get_weights(network) is None: |
| 384 | if not param.is_buffer: |
| 385 | logger.debug( |
| 386 | f"Parameter {name} {param.raw_value.shape} {param.raw_value.dtype} was created" |
| 387 | " but unused in forward method, so not materialized to TRT network" |
| 388 | ) |
| 389 | continue |
| 390 | if not param.set_name(name, network): |
| 391 | raise RuntimeError(f'Failed to set weight: {name}') |
| 392 | # This mark_weights_refittable has no side effect when refit_individual is not enabled. |
| 393 | network.trt_network.mark_weights_refittable(name) |
| 394 | |
| 395 | network._fill_weights() |
| 396 | tok = time.time() |
| 397 | t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) |
| 398 | logger.info( |
| 399 | f'Total time to initialize the weights in network {network.trt_network.name}: {t}' |
| 400 | ) |
| 401 | |
| 402 | # Build engine |
| 403 | logger.info(f'Build TensorRT engine {network.trt_network.name}') |
| 404 | tik = time.time() |
| 405 | engine = self.trt_builder.build_serialized_network( |
| 406 | network.trt_network, builder_config.trt_builder_config) |
| 407 | assert engine is not None, 'Engine building failed, please check the error log.' |
| 408 | |
| 409 | tok = time.time() |
| 410 | t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) |
| 411 | logger.info(f'Total time of building {network.trt_network.name}: {t}') |