MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / refit_engine

Method refit_engine

tensorrt_llm/builder.py:314–351  ·  view source on GitHub ↗

@brief: Refit one TensorRT engine using weights from the network. The user should guarantee that the engine is built with the REFIT flag and that the network has the same structure as the engine. @param engine_buffer: A serialized TensorRT engine. @param network: Network object. @return: A serialized TRT engine if the refit succeeds, None otherwise.

(self, network: Network, engine_buffer)

Source from the content-addressed store, hash-verified

312
@_is_building
def refit_engine(self, network: Network, engine_buffer) -> trt.IHostMemory:
    '''
    @brief: Refit one TensorRT engine using weights from the network.
        The caller must guarantee that the engine was built with the REFIT
        flag and that the network has the same structure as the engine.
    @param network: Network object whose named parameters supply the new
        weights.
    @param engine_buffer: A serialized TensorRT engine.
    @return: A serialized TRT engine if the refit succeeds, None otherwise
        (engine deserialization fails, the network has no named parameters,
        a parameter has no weights or is rejected by the refitter, or the
        refit itself fails). Failures are reported through logger.error.
    '''
    assert isinstance(network, Network)
    logger.info('Refit TRT engine')
    runtime = trt.Runtime(logger.trt_logger)
    engine = runtime.deserialize_cuda_engine(engine_buffer)
    # deserialize_cuda_engine returns None on failure. Handing None to
    # trt.Refitter would surface as an unrelated error, so report it here
    # and keep the "None on failure" contract.
    if engine is None:
        logger.error('Failed to deserialize engine for refitting.')
        return None

    tik = time.time()

    # Refit engine
    refitter = trt.Refitter(engine, logger.trt_logger)
    if network.named_parameters is None:
        logger.error(
            'Please set named parameters before building multiple engines.')
        return None

    # Hold a reference to every weights object until refit_cuda_engine()
    # has run. NOTE(review): TensorRT's refitter is documented not to copy
    # weight memory; keeping them alive here is a cheap safeguard in case
    # _get_weights() returns a freshly created object.
    staged_weights = []
    for name, param in network.named_parameters:
        # Fetch once; the original called _get_weights() twice per parameter.
        weights = param._get_weights()
        if weights is None or not refitter.set_named_weights(name, weights):
            logger.error(f'Failed to refit weight: {name}')
            return None
        staged_weights.append(weights)

    if not refitter.refit_cuda_engine():
        logger.error('Failed to refit engine.')
        return None

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Total time of refitting {engine.name}: {t}')
    serialized_engine = engine.serialize()
    return serialized_engine
352
353 @_is_building
354 def build_engine(self,

Callers 1

Calls 4

_get_weights · Method · 0.80
info · Method · 0.45
error · Method · 0.45
serialize · Method · 0.45

Tested by 1