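@brief Create a builder config with the given precision and timing cache. @param precision: one of the allowed precisions, defined in Builder._ALLOWED_PRECISIONS @param timing_cache: a timing cache object or a path to a timing cache file @param tensor_parallel: number of GPUs used for tensor parallelism @param kwargs: any other arguments users would like to attach to the config object as attributes @param use_refit: set to accelerate multi-GPU building; build the engine for one GPU and refit it for the others @param int8: whether to build with INT8 enabled; cannot be used together with the refit option @return: a BuilderConfig object, or None if creation failed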
(self,
precision: Union[str, trt.DataType],
timing_cache: Union[str, Path,
trt.ITimingCache] = None,
tensor_parallel: int = 1,
use_refit: bool = False,
int8: bool = False,
strongly_typed: bool = True,
force_num_profiles: Optional[int] = None,
profiling_verbosity: str = "layer_names_only",
use_strip_plan: bool = False,
weight_streaming: bool = False,
precision_constraints: Optional[str] = "obey",
**kwargs)
| 125 | self.trt_builder.create_network(explicit_batch_flag)) |
| 126 | |
| 127 | def create_builder_config(self, |
| 128 | precision: Union[str, trt.DataType], |
| 129 | timing_cache: Union[str, Path, |
| 130 | trt.ITimingCache] = None, |
| 131 | tensor_parallel: int = 1, |
| 132 | use_refit: bool = False, |
| 133 | int8: bool = False, |
| 134 | strongly_typed: bool = True, |
| 135 | force_num_profiles: Optional[int] = None, |
| 136 | profiling_verbosity: str = "layer_names_only", |
| 137 | use_strip_plan: bool = False, |
| 138 | weight_streaming: bool = False, |
| 139 | precision_constraints: Optional[str] = "obey", |
| 140 | **kwargs) -> BuilderConfig: |
| 141 | ''' @brief Create a builder config with given precisions and timing cache |
| 142 | @param precision: one of allowed precisions, defined in Builder._ALLOWED_PRECISIONS |
| 143 | @param timing_cache: a timing cache object or a path to a timing cache file |
| 144 | @param tensor_parallel: number of GPUs used for tensor parallel |
| 145 | @param kwargs: any other arguments users would like to attach to the config object as attributes |
| 146 | @param refit: set to accelerate multi-gpu building, build engine for 1 gpu and refit for the others |
| 147 | @param int8: whether to build with int8 enabled or not. Can't be used together with refit option |
| 148 | @return: A BuilderConfig object, return None if failed |
| 149 | ''' |
| 150 | self.strongly_typed = self.strongly_typed and strongly_typed |
| 151 | |
| 152 | quant_mode = kwargs.get("quant_mode", QuantMode(0)) |
| 153 | if not strongly_typed and precision not in self._ALLOWED_PRECISIONS: |
| 154 | logger.error( |
| 155 | f"precision should be one of {self._ALLOWED_PRECISIONS}") |
| 156 | |
| 157 | config = self.trt_builder.create_builder_config() |
| 158 | if weight_streaming: |
| 159 | config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING) |
| 160 | if not self.strongly_typed: |
| 161 | fp8 = quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache() |
| 162 | if precision == 'float16' or precision == trt.DataType.HALF: |
| 163 | config.set_flag(trt.BuilderFlag.FP16) |
| 164 | if precision_constraints == 'obey': |
| 165 | config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) |
| 166 | elif precision == 'bfloat16' or precision == trt.DataType.BF16: |
| 167 | config.set_flag(trt.BuilderFlag.BF16) |
| 168 | if precision_constraints == 'obey': |
| 169 | config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) |
| 170 | if int8: |
| 171 | config.set_flag(trt.BuilderFlag.INT8) |
| 172 | if fp8: |
| 173 | config.set_flag(trt.BuilderFlag.FP8) |
| 174 | if precision_constraints == 'obey': |
| 175 | config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) |
| 176 | |
| 177 | if use_refit: |
| 178 | config.set_flag(trt.BuilderFlag.REFIT) |
| 179 | |
| 180 | # Use fine-grained refit when strip plan is enabled in TRT10.2+. |
| 181 | if use_strip_plan: |
| 182 | config.set_flag(trt.BuilderFlag.REFIT_INDIVIDUAL) |
| 183 | |
| 184 | if use_strip_plan: |