| 91 | |
| 92 | |
| 93 | class Builder(): |
| 94 | |
# Values accepted for the `precision` argument of create_builder_config:
# string names alongside trt.DataType members.
_ALLOWED_PRECISIONS = [
    'float32',
    'float16',
    'bfloat16',
    trt.DataType.HALF,
    trt.DataType.FLOAT,
    trt.DataType.BF16,
]
| 99 | |
def __init__(self):
    """Wrap a new trt.Builder and default to strongly-typed networks."""
    super().__init__()
    # The underlying TensorRT builder is constructed with logger.trt_logger.
    trt_log = logger.trt_logger
    self._trt_builder = trt.Builder(trt_log)
    # Read by create_network; create_builder_config ANDs this flag with
    # its own `strongly_typed` argument.
    self.strongly_typed = True
| 104 | |
@property
def trt_builder(self) -> trt.Builder:
    """Return the wrapped trt.Builder instance stored by __init__."""
    return self._trt_builder
| 108 | |
def create_network(self) -> Network:
    """Create a TensorRT network definition and wrap it in a Network.

    The creation flags include EXPLICIT_BATCH when the installed TensorRT
    enum still has that member, and STRONGLY_TYPED when
    self.strongly_typed is truthy.

    @return: the result of Network()._init(...) applied to the network
        created by the underlying trt.Builder
    """
    flag_enum = trt.NetworkDefinitionCreationFlag
    creation_flags = 0
    # Explicit batch flag will be deprecated in TRT 10, so only set it
    # when the enum still exposes that member.
    if "EXPLICIT_BATCH" in flag_enum.__members__:
        creation_flags |= 1 << int(flag_enum.EXPLICIT_BATCH)
    if self.strongly_typed:
        creation_flags |= 1 << int(flag_enum.STRONGLY_TYPED)
    return Network()._init(self.trt_builder.create_network(creation_flags))
| 126 | |
| 127 | def create_builder_config(self, |
| 128 | precision: Union[str, trt.DataType], |
| 129 | timing_cache: Union[str, Path, |
| 130 | trt.ITimingCache] = None, |
| 131 | tensor_parallel: int = 1, |
| 132 | use_refit: bool = False, |
| 133 | int8: bool = False, |
| 134 | strongly_typed: bool = True, |
| 135 | force_num_profiles: Optional[int] = None, |
| 136 | profiling_verbosity: str = "layer_names_only", |
| 137 | use_strip_plan: bool = False, |
| 138 | weight_streaming: bool = False, |
| 139 | precision_constraints: Optional[str] = "obey", |
| 140 | **kwargs) -> BuilderConfig: |
| 141 | ''' @brief Create a builder config with given precisions and timing cache |
| 142 | @param precision: one of allowed precisions, defined in Builder._ALLOWED_PRECISIONS |
| 143 | @param timing_cache: a timing cache object or a path to a timing cache file |
| 144 | @param tensor_parallel: number of GPUs used for tensor parallel |
| 145 | @param kwargs: any other arguments users would like to attach to the config object as attributes |
| 146 | @param refit: set to accelerate multi-gpu building, build engine for 1 gpu and refit for the others |
| 147 | @param int8: whether to build with int8 enabled or not. Can't be used together with refit option |
| 148 | @return: A BuilderConfig object, return None if failed |
| 149 | ''' |
| 150 | self.strongly_typed = self.strongly_typed and strongly_typed |
no outgoing calls