| 162 | self.outputs = {} |
| 163 | |
def refit_from_dict(self, refit_weights, is_fp16):
    """Refit ``self.engine`` in place from a mapping of weight name -> tensor.

    Every key of ``refit_weights`` must be a refittable weight name reported
    by ``trt.Refitter.get_all_weights()`` for this engine. Engine weights that
    are not present in the mapping are not set here.

    Args:
        refit_weights: Mapping from TensorRT weight name to a torch tensor
            with the new values. The mapping itself is not modified.
        is_fp16: If True, weights are converted to float16 and passed to
            TensorRT as ``trt.DataType.HALF``; otherwise they are converted to
            float32 and passed as ``trt.DataType.FLOAT``.

    Raises:
        ValueError: If ``refit_weights`` contains names that are not
            refittable weights of this engine.
        RuntimeError: If ``refit_cuda_engine()`` reports failure.
    """
    refitter = trt.Refitter(self.engine, TRT_LOGGER)
    engine_weight_names = set(refitter.get_all_weights())

    # Validate before doing any conversion work; an explicit raise survives
    # `python -O`, unlike the previous assert.
    unknown = set(refit_weights) - engine_weight_names
    if unknown:
        raise ValueError(
            f"Weights not refittable in this engine: {sorted(unknown)}"
        )

    if is_fp16:
        trt_datatype = trt.DataType.HALF
        torch_dtype = torch.float16
    else:
        trt_datatype = trt.DataType.FLOAT
        torch_dtype = torch.float32

    # trt.Weights only wraps a raw pointer, so the converted host tensors are
    # kept referenced here until refit_cuda_engine() has run. Using a local
    # dict avoids overwriting the caller's tensors in refit_weights.
    host_tensors = {}
    for name, tensor in refit_weights.items():
        # Cast to exactly the dtype declared to TensorRT and force a dense
        # buffer, since TensorRT reads numel() elements starting at data_ptr().
        # Tensors are always staged on the host, so the default (host)
        # location of set_named_weights applies.
        host = tensor.to(device="cpu", dtype=torch_dtype).contiguous()
        host_tensors[name] = host
        refitter.set_named_weights(
            name, trt.Weights(trt_datatype, host.data_ptr(), host.numel())
        )

    if not refitter.refit_cuda_engine():
        # Raising (instead of exit(0)) lets callers handle the failure and no
        # longer reports success to the shell when the refit did not happen.
        raise RuntimeError("Error: failed to refit new weights.")
| 202 | |
| 203 | def build( |
| 204 | self, |