MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / refit_engine

Function refit_engine

tensorrt_llm/commands/refit.py:25–106  ·  view source on GitHub ↗
(engine_path: str, refit_engine_dir: str, checkpoint_dir: str,
                 engine_config: EngineConfig, fixed_weights_names: list)

Source from the content-addressed store, hash-verified

23
24@_is_building
25def refit_engine(engine_path: str, refit_engine_dir: str, checkpoint_dir: str,
26 engine_config: EngineConfig, fixed_weights_names: list):
27 # This function loops through all weights in the model and does a textual match between
28 # checkpoint weight names and engine weight names.
29 rank = int(ENGINE_RE.fullmatch(os.path.basename(engine_path)).group(1))
30 tik = time.time()
31 with open(engine_path,
32 "rb") as f, trt.Runtime(logger.trt_logger) as runtime:
33 engine = runtime.deserialize_cuda_engine(f.read())
34 tok = time.time()
35 t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
36 logger.info(f'Load TRT engine time: {t}')
37
38 refitter = trt.Refitter(engine, logger.trt_logger)
39 refittable_weights = set(refitter.get_all_weights())
40
41 # Load model.
42 tik = time.time()
43 rank_config = PretrainedConfig.from_dict(
44 engine_config.pretrained_config.to_dict())
45 rank_config.set_rank(rank)
46
47 architecture = rank_config.architecture
48 assert architecture in MODEL_MAP, \
49 f"Unsupported model architecture: {architecture}"
50 model_cls = MODEL_MAP[architecture]
51 model = model_cls.from_checkpoint(checkpoint_dir, config=rank_config)
52
53 tok = time.time()
54 t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
55 logger.info(f'Load checkpoint(s) time: {t}')
56
57 # There are weights preprocess during optimize model.
58 tik = time.time()
59 build_config = engine_config.build_config.model_copy(deep=True)
60 optimize_model_with_config(model, build_config)
61 tok = time.time()
62 t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
63 logger.info(f'Preprocess weights time: {t}')
64
65 # Refit engine.
66 tik = time.time()
67 refitted_weights = []
68 for name, buf in model.named_parameters():
69 if name in refittable_weights:
70 assert buf.is_inited, f"Failed because weight `{name}` is not initialized in model."
71 weight = buf._value
72 weights_value = trt.Weights(np_dtype_to_trt(weight.dtype),
73 weight.ctypes.data, weight.size)
74 assert refitter.set_named_weights(
75 name, weights_value), f'Failed to refit weight: `{name}`'
76 refitted_weights.append(name)
77 else:
78 if name not in fixed_weights_names:
79 logger.warning(
80 f"model weights `{name}` (shape: {buf._value.shape}) is not refittable, this means that we might not be able to update the engine using fine-tuned checkpoint!"
81 )
82

Callers 1

refit (Function) · 0.85

Calls 10

np_dtype_to_trt (Function) · 0.90
info (Method) · 0.45
from_dict (Method) · 0.45
to_dict (Method) · 0.45
set_rank (Method) · 0.45
from_checkpoint (Method) · 0.45
named_parameters (Method) · 0.45
append (Method) · 0.45
warning (Method) · 0.45

Tested by

no test coverage detected