MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / create_builder_config

Method create_builder_config

tensorrt_llm/builder.py:127–236  ·  view source on GitHub ↗

@brief Create a builder config with given precisions and timing cache @param precision: one of allowed precisions, defined in Builder._ALLOWED_PRECISIONS @param timing_cache: a timing cache object or a path to a timing cache file @param tensor_parallel: number of GPUs used for tensor parallel …

(self,
                              precision: Union[str, trt.DataType],
                              timing_cache: Union[str, Path,
                                                  trt.ITimingCache] = None,
                              tensor_parallel: int = 1,
                              use_refit: bool = False,
                              int8: bool = False,
                              strongly_typed: bool = True,
                              force_num_profiles: Optional[int] = None,
                              profiling_verbosity: str = "layer_names_only",
                              use_strip_plan: bool = False,
                              weight_streaming: bool = False,
                              precision_constraints: Optional[str] = "obey",
                              **kwargs)

Source from the content-addressed store, hash-verified

125 self.trt_builder.create_network(explicit_batch_flag))
126
127 def create_builder_config(self,
128 precision: Union[str, trt.DataType],
129 timing_cache: Union[str, Path,
130 trt.ITimingCache] = None,
131 tensor_parallel: int = 1,
132 use_refit: bool = False,
133 int8: bool = False,
134 strongly_typed: bool = True,
135 force_num_profiles: Optional[int] = None,
136 profiling_verbosity: str = "layer_names_only",
137 use_strip_plan: bool = False,
138 weight_streaming: bool = False,
139 precision_constraints: Optional[str] = "obey",
140 **kwargs) -> BuilderConfig:
141 ''' @brief Create a builder config with given precisions and timing cache
142 @param precision: one of allowed precisions, defined in Builder._ALLOWED_PRECISIONS
143 @param timing_cache: a timing cache object or a path to a timing cache file
144 @param tensor_parallel: number of GPUs used for tensor parallel
145 @param kwargs: any other arguments users would like to attach to the config object as attributes
146 @param refit: set to accelerate multi-gpu building, build engine for 1 gpu and refit for the others
147 @param int8: whether to build with int8 enabled or not. Can't be used together with refit option
148 @return: A BuilderConfig object, return None if failed
149 '''
150 self.strongly_typed = self.strongly_typed and strongly_typed
151
152 quant_mode = kwargs.get("quant_mode", QuantMode(0))
153 if not strongly_typed and precision not in self._ALLOWED_PRECISIONS:
154 logger.error(
155 f"precision should be one of {self._ALLOWED_PRECISIONS}")
156
157 config = self.trt_builder.create_builder_config()
158 if weight_streaming:
159 config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)
160 if not self.strongly_typed:
161 fp8 = quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache()
162 if precision == 'float16' or precision == trt.DataType.HALF:
163 config.set_flag(trt.BuilderFlag.FP16)
164 if precision_constraints == 'obey':
165 config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
166 elif precision == 'bfloat16' or precision == trt.DataType.BF16:
167 config.set_flag(trt.BuilderFlag.BF16)
168 if precision_constraints == 'obey':
169 config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
170 if int8:
171 config.set_flag(trt.BuilderFlag.INT8)
172 if fp8:
173 config.set_flag(trt.BuilderFlag.FP8)
174 if precision_constraints == 'obey':
175 config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
176
177 if use_refit:
178 config.set_flag(trt.BuilderFlag.REFIT)
179
180 # Use fine-grained refit when strip plan is enabled in TRT10.2+.
181 if use_strip_plan:
182 config.set_flag(trt.BuilderFlag.REFIT_INDIVIDUAL)
183
184 if use_strip_plan:

Callers 15

build_bert Function · 0.95
_construct_execution Method · 0.95
_construct_execution Method · 0.95
build_engine Method · 0.95
test_fp4_gemm Method · 0.95

Calls 9

QuantMode Class · 0.85
BuilderConfig Class · 0.85
has_fp8_kv_cache Method · 0.80
trt_gte Function · 0.70
get Method · 0.45
error Method · 0.45
has_fp8_qdq Method · 0.45
warning Method · 0.45
_init Method · 0.45

Tested by 15

_construct_execution Method · 0.76
_construct_execution Method · 0.76
build_engine Method · 0.76
test_fp4_gemm Method · 0.76