@brief: Refit one TensorRT engine using weights from the network. The user must guarantee that the engine was built with the REFIT flag and that the network has the same structure as the engine. @param engine_buffer: A serialized TensorRT engine. @param network: Network object.
(self, network: Network, engine_buffer)
| 312 | |
@_is_building
def refit_engine(self, network: Network, engine_buffer) -> trt.IHostMemory:
    '''
    @brief: Refit one TensorRT engine using weights from the network.
        The user must guarantee that the engine was built with the REFIT
        flag and that the network has the same structure as the engine.
    @param network: Network object whose named parameters supply the new
        weights.
    @param engine_buffer: A serialized TensorRT engine.
    @return: A serialized TRT engine if the refit succeeds, None otherwise
        (engine deserialization failed, named parameters were not set, a
        weight was missing or rejected by the refitter, or the final refit
        failed).
    '''
    assert isinstance(network, Network)
    logger.info('Refit TRT engine')
    runtime = trt.Runtime(logger.trt_logger)
    engine = runtime.deserialize_cuda_engine(engine_buffer)
    # deserialize_cuda_engine yields None on failure; without this guard
    # trt.Refitter(None, ...) would raise instead of honouring the
    # "None on failure" contract documented above.
    if engine is None:
        logger.error('Failed to deserialize engine for refitting.')
        return None

    tik = time.time()

    # Refit engine
    refitter = trt.Refitter(engine, logger.trt_logger)
    if network.named_parameters is None:
        logger.error(
            'Please set named parameters before building multiple engines.')
        return None

    for name, param in network.named_parameters:
        # Fetch the weights once per parameter rather than twice.
        weights = param._get_weights()
        if weights is None or not refitter.set_named_weights(name, weights):
            logger.error(f'Failed to refit weight: {name}')
            return None

    if not refitter.refit_cuda_engine():
        logger.error('Failed to refit engine.')
        return None

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Total time of refitting {engine.name}: {t}')
    return engine.serialize()
| 352 | |
| 353 | @_is_building |
| 354 | def build_engine(self, |