hub / github.com/NVIDIA/TensorRT-LLM / get_engine_name

Function get_engine_name

tensorrt_llm/runtime/model_runner.py:41–64

Get the serialized engine file name. Args: model (str): Model name, e.g., bloom, gpt. dtype (str): Data type, e.g., float32, float16, bfloat16. tp_size (int): The size of tensor parallel. pp_size (int): The size of pipeline parallel. rank (int): The rank id. Returns: str: The serialized engine file name.

(model: str, dtype: str, tp_size: int, pp_size: int, rank: int) -> str

Source from the content-addressed store, hash-verified

def get_engine_name(model: str, dtype: str, tp_size: int, pp_size: int,
                    rank: int) -> str:
    """
    Get the serialized engine file name.

    Args:
        model (str):
            Model name, e.g., bloom, gpt.
        dtype (str):
            Data type, e.g., float32, float16, bfloat16,
        tp_size (int):
            The size of tensor parallel.
        pp_size (int):
            The size of pipeline parallel.
        rank (int):
            The rank id.

    Returns:
        str: The serialized engine file name.
    """
    if pp_size == 1:
        return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
    return '{}_{}_tp{}_pp{}_rank{}.engine'.format(model, dtype, tp_size,
                                                  pp_size, rank)
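
Example

The naming rule in get_engine_name can be illustrated with a small standalone sketch. The function body below is a local copy of the logic from the listing, so the snippet runs without TensorRT-LLM installed. The helper list_engine_files is hypothetical and not part of the library; it assumes one engine file per rank and a world size of tp_size * pp_size, which the listing itself does not state.

```python
from pathlib import Path
from typing import List


def get_engine_name(model: str, dtype: str, tp_size: int, pp_size: int,
                    rank: int) -> str:
    # Local copy of the naming logic from the listing: the "pp" component
    # is omitted from the file name when pp_size == 1.
    if pp_size == 1:
        return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
    return '{}_{}_tp{}_pp{}_rank{}.engine'.format(model, dtype, tp_size,
                                                  pp_size, rank)


def list_engine_files(engine_dir: str, model: str, dtype: str, tp_size: int,
                      pp_size: int) -> List[Path]:
    # Hypothetical helper (not in TensorRT-LLM). Assumes one engine file
    # per rank and world_size == tp_size * pp_size.
    world_size = tp_size * pp_size
    return [
        Path(engine_dir) / get_engine_name(model, dtype, tp_size, pp_size, rank)
        for rank in range(world_size)
    ]


# Tensor parallelism only: no "pp" component in the name.
tp_only = get_engine_name("gpt", "float16", 2, 1, 0)
print(tp_only)  # gpt_float16_tp2_rank0.engine

# Tensor and pipeline parallelism: the "pp" component is included.
tp_pp = get_engine_name("gpt", "float16", 2, 2, 3)
print(tp_pp)  # gpt_float16_tp2_pp2_rank3.engine

# Expected file paths for every rank under the stated assumption.
for path in list_engine_files("engines", "gpt", "float16", 2, 2):
    print(path)
```

Because the pp component is dropped only when pp_size == 1, a name built for a single-stage pipeline never contains "_pp1_"; code that searches for engine files by pattern needs to handle both forms.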

Callers 1

from_dir · Method · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected