MCP · Copy · Index your code
hub / github.com/NVIDIA/Stable-Diffusion-WebUI-TensorRT / Engine

Class Engine

utilities.py:134–331  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

132
133
134class Engine:
def __init__(self, engine_path):
    """Create an empty wrapper for the TensorRT engine at ``engine_path``.

    Nothing is loaded here: ``engine`` and ``context`` start as ``None`` and
    the buffer/tensor maps start as empty ``OrderedDict``s.

    Args:
        engine_path: Location of the engine file; it is only stored here.
    """
    self.engine_path = engine_path
    self.engine = None  # TensorRT engine; refit_from_dict hands it to trt.Refitter
    self.context = None
    self.buffers = OrderedDict()
    self.tensors = OrderedDict()
    # reset() also (re)creates these two maps; initialise them here as well so
    # the set of attributes does not depend on whether reset() was ever called.
    self.inputs = {}
    self.outputs = {}
    self.cuda_graph_instance = None  # cuda graph (none captured here)
145
def __del__(self):
    """Drop this wrapper's references to the engine, context and buffer maps.

    Each attribute is removed only if it is still present. For example,
    ``__init__`` may not have run to completion, or the attribute may already
    have been deleted elsewhere. In those cases finalization no longer raises
    ``AttributeError``; Python would print that as an "Exception ignored in
    __del__" warning.
    """
    # Same release order as before: engine, context, buffers, tensors.
    for attr in ("engine", "context", "buffers", "tensors"):
        vars(self).pop(attr, None)
151
def reset(self, engine_path=None):
    """Drop the current engine state and retarget the wrapper at ``engine_path``.

    ``engine`` and ``context`` become ``None``, and the buffer/tensor maps
    become fresh empty ``OrderedDict``s (the same starting values ``__init__``
    uses). ``inputs`` and ``outputs`` become empty dicts. The object stays
    consistent, so ``reset()`` may safely be called repeatedly.

    Args:
        engine_path: New engine location to store; defaults to ``None``.
    """
    # Rebind instead of ``del``: deleting ``engine``/``context`` without
    # reassigning them left the attributes missing. A second reset(),
    # __del__, or any later read of ``self.engine`` then raised AttributeError.
    # Rebinding still drops this object's references to the old values, in
    # the same order as before: engine, context, buffers, tensors.
    self.engine = None
    self.context = None
    self.buffers = OrderedDict()
    self.tensors = OrderedDict()
    self.engine_path = engine_path
    self.inputs = {}
    self.outputs = {}
    # NOTE(review): cuda_graph_instance is left untouched, as before. If it was
    # captured against the previous engine it is probably stale; confirm
    # whether it should also be cleared here.
163
164 def refit_from_dict(self, refit_weights, is_fp16):
165 # Initialize refitter
166 refitter = trt.Refitter(self.engine, TRT_LOGGER)
167
168 refitted_weights = set()
169 # iterate through all tensorrt refittable weights
170 for trt_weight_name in refitter.get_all_weights():
171 if trt_weight_name not in refit_weights:
172 continue
173
174 # get weight from state dict
175 trt_datatype = trt.DataType.FLOAT
176 if is_fp16:
177 refit_weights[trt_weight_name] = refit_weights[trt_weight_name].half()
178 trt_datatype = trt.DataType.HALF
179
180 # trt.Weight and trt.TensorLocation
181 refit_weights[trt_weight_name] = refit_weights[trt_weight_name].cpu()
182 trt_wt_tensor = trt.Weights(
183 trt_datatype,
184 refit_weights[trt_weight_name].data_ptr(),
185 torch.numel(refit_weights[trt_weight_name]),
186 )
187 trt_wt_location = (
188 trt.TensorLocation.DEVICE
189 if refit_weights[trt_weight_name].is_cuda
190 else trt.TensorLocation.HOST
191 )

Callers 2

export_trt (function) · 0.90
activate (method) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected