| 132 | |
| 133 | |
| 134 | class Engine: |
def __init__(self, engine_path):
    """Remember where the engine lives and start with nothing loaded.

    Args:
        engine_path: Path of the engine file; stored unchanged on the
            instance as ``engine_path``.
    """
    self.engine_path = engine_path

    # Nothing is loaded or activated yet, so both handles start out as None.
    self.engine = self.context = None

    # Two separate, initially empty, insertion-ordered registries.
    self.buffers = OrderedDict()
    self.tensors = OrderedDict()

    # CUDA graph handle; none exists at construction time.
    self.cuda_graph_instance = None
| 145 | |
def __del__(self):
    """Drop this object's references to engine, context, buffers and tensors.

    The attributes are removed in that order. An attribute that is not
    present on the instance is skipped instead of raising, because a
    finalizer can run on an object whose attributes were never assigned
    (construction failed part-way) or were already deleted. An exception
    escaping ``__del__`` is not propagated; the interpreter only reports it
    on stderr, so raising here would produce noise without any benefit.
    """
    for attr in ("engine", "context", "buffers", "tensors"):
        try:
            delattr(self, attr)
        except AttributeError:
            # Attribute is absent: there is no reference left to release.
            pass
| 151 | |
def reset(self, engine_path=None):
    """Release the current engine state and retarget this object.

    After the call the instance holds no engine or context (both are
    ``None``), empty ``buffers``/``tensors`` registries, empty ``inputs``
    and ``outputs`` dicts, and the new ``engine_path``.

    Args:
        engine_path: Path to store for the next engine. Defaults to None.
    """
    # Rebind instead of ``del``: the old objects are released in the same
    # order (engine, context, buffers, tensors), but the attributes stay
    # defined. Deleting them outright made a second reset(), the finalizer,
    # or any later read of ``self.engine`` fail with AttributeError.
    self.engine = None
    self.context = None
    self.buffers = OrderedDict()
    self.tensors = OrderedDict()

    self.engine_path = engine_path

    self.inputs = {}
    self.outputs = {}
    # NOTE(review): cuda_graph_instance is left untouched here, as before.
    # Confirm whether a graph captured for the previous engine should be
    # dropped on reset as well.
| 163 | |
| 164 | def refit_from_dict(self, refit_weights, is_fp16): |
| 165 | # Initialize refitter |
| 166 | refitter = trt.Refitter(self.engine, TRT_LOGGER) |
| 167 | |
| 168 | refitted_weights = set() |
| 169 | # iterate through all tensorrt refittable weights |
| 170 | for trt_weight_name in refitter.get_all_weights(): |
| 171 | if trt_weight_name not in refit_weights: |
| 172 | continue |
| 173 | |
| 174 | # get weight from state dict |
| 175 | trt_datatype = trt.DataType.FLOAT |
| 176 | if is_fp16: |
| 177 | refit_weights[trt_weight_name] = refit_weights[trt_weight_name].half() |
| 178 | trt_datatype = trt.DataType.HALF |
| 179 | |
| 180 | # trt.Weight and trt.TensorLocation |
| 181 | refit_weights[trt_weight_name] = refit_weights[trt_weight_name].cpu() |
| 182 | trt_wt_tensor = trt.Weights( |
| 183 | trt_datatype, |
| 184 | refit_weights[trt_weight_name].data_ptr(), |
| 185 | torch.numel(refit_weights[trt_weight_name]), |
| 186 | ) |
| 187 | trt_wt_location = ( |
| 188 | trt.TensorLocation.DEVICE |
| 189 | if refit_weights[trt_weight_name].is_cuda |
| 190 | else trt.TensorLocation.HOST |
| 191 | ) |
no outgoing calls
no test coverage detected