MCPcopy Index your code
hub / github.com/NVIDIA/Stable-Diffusion-WebUI-TensorRT / refit_from_dict

Method refit_from_dict

utilities.py:164–201  ·  view source on GitHub ↗
(self, refit_weights, is_fp16)

Source from the content-addressed store, hash-verified

162 self.outputs = {}
163
def refit_from_dict(self, refit_weights, is_fp16):
    """Refit ``self.engine`` in place with weights taken from a state dict.

    Only the entries of ``refit_weights`` are handed to TensorRT; every key
    must name a refittable weight reported by the engine's refitter.

    Args:
        refit_weights: Mapping from TensorRT weight name to a torch tensor.
            NOTE: the dict is modified in place. Each consumed entry is
            replaced by a contiguous host (CPU) copy in the target dtype.
            Writing the converted tensor back into the dict keeps its buffer
            alive while TensorRT holds a raw pointer to it. It also lets the
            original (possibly device) tensor be released if nothing else
            references it.
        is_fp16: If True, weights are converted to float16 and declared as
            ``trt.DataType.HALF``. Otherwise they are converted to float32
            and declared as ``trt.DataType.FLOAT``.

    Raises:
        ValueError: If some keys of ``refit_weights`` are not refittable
            weights of the engine. Raised before any tensor is converted,
            so the dict is left untouched in that case.
        RuntimeError: If TensorRT fails to refit the engine.
    """
    refitter = trt.Refitter(self.engine, TRT_LOGGER)
    engine_weight_names = refitter.get_all_weights()

    # Fail fast, before converting (and mutating) anything, if the caller
    # supplied weights the engine cannot accept. An explicit raise is used
    # instead of `assert`, which is stripped under `python -O`.
    unmatched = set(refit_weights).difference(engine_weight_names)
    if unmatched:
        preview = ", ".join(sorted(unmatched)[:10])
        raise ValueError(
            f"{len(unmatched)} weight(s) are not refittable in this engine: "
            f"{preview}{', ...' if len(unmatched) > 10 else ''}"
        )

    if is_fp16:
        trt_datatype = trt.DataType.HALF
        torch_dtype = torch.float16
    else:
        trt_datatype = trt.DataType.FLOAT
        torch_dtype = torch.float32

    for trt_weight_name in engine_weight_names:
        if trt_weight_name not in refit_weights:
            continue

        # trt.Weights wraps a raw pointer + element count. The buffer must
        # therefore be host memory (set_named_weights defaults to HOST),
        # contiguous, and exactly the dtype declared to TensorRT. Convert on
        # the original device first, then copy, to halve transfer size in
        # fp16 mode.
        host_tensor = (
            refit_weights[trt_weight_name].to(torch_dtype).cpu().contiguous()
        )
        # Keep the converted tensor referenced until refit_cuda_engine()
        # has consumed the pointer below.
        refit_weights[trt_weight_name] = host_tensor

        refitter.set_named_weights(
            trt_weight_name,
            trt.Weights(trt_datatype, host_tensor.data_ptr(), host_tensor.numel()),
        )

    if not refitter.refit_cuda_engine():
        # Previously this printed and called exit(0), which terminated the
        # whole process while reporting success. Surface the failure instead.
        raise RuntimeError("Error: failed to refit new weights.")
202
203 def build(
204 self,

Callers 1

apply_loras · Method · 0.80

Calls

no outgoing calls

Tested by

no test coverage detected