MCPcopy Index your code
hub / github.com/NVIDIA/Stable-Diffusion-WebUI-TensorRT / export_weights_map

Method export_weights_map

model_helper.py:151–213  ·  view source on GitHub ↗
(self, onnx_opt_path: str, weights_map_path: dict)

Source from the content-addressed store, hash-verified

149
150 # Helper utility for weights map
151 def export_weights_map(self, onnx_opt_path: str, weights_map_path: dict):
152 onnx_opt_dir = onnx_opt_path
153 state_dict = self.unet.state_dict()
154 onnx_opt_model = onnx.load(onnx_opt_path)
155
156 # Create initializer data hashes
157 def init_hash_map(onnx_opt_model):
158 initializer_hash_mapping = {}
159 for initializer in onnx_opt_model.graph.initializer:
160 initializer_data = numpy_helper.to_array(
161 initializer, base_dir=onnx_opt_dir
162 ).astype(np.float16)
163 initializer_hash = hash(initializer_data.data.tobytes())
164 initializer_hash_mapping[initializer.name] = (
165 initializer_hash,
166 initializer_data.shape,
167 )
168 return initializer_hash_mapping
169
170 initializer_hash_mapping = init_hash_map(onnx_opt_model)
171
172 weights_name_mapping = {}
173 weights_shape_mapping = {}
174 # set to keep track of initializers already added to the name_mapping dict
175 initializers_mapped = set()
176 for wt_name, wt in state_dict.items():
177 # get weight hash
178 wt = wt.cpu().detach().numpy().astype(np.float16)
179 wt_hash = hash(wt.data.tobytes())
180 wt_t_hash = hash(np.transpose(wt).data.tobytes())
181
182 for initializer_name, (
183 initializer_hash,
184 initializer_shape,
185 ) in initializer_hash_mapping.items():
186 # Due to constant folding, some weights are transposed during export
187 # To account for the transpose op, we compare the initializer hash to the
188 # hash for the weight and its transpose
189 if wt_hash == initializer_hash or wt_t_hash == initializer_hash:
190 # The assert below ensures there is a 1:1 mapping between
191 # PyTorch and ONNX weight names. It can be removed in cases where 1:many
192 # mapping is found and name_mapping[wt_name] = list()
193 assert initializer_name not in initializers_mapped
194 weights_name_mapping[wt_name] = initializer_name
195 initializers_mapped.add(initializer_name)
196 is_transpose = False if wt_hash == initializer_hash else True
197 weights_shape_mapping[wt_name] = (
198 initializer_shape,
199 is_transpose,
200 )
201
202 # Sanity check: Were any weights not matched
203 if wt_name not in weights_name_mapping:
204 print(
205 f"[I] PyTorch weight {wt_name} not matched with any ONNX initializer"
206 )
207 print(
208 f"[I] UNet: {len(weights_name_mapping.keys())} PyTorch weights were matched with ONNX initializers"

Callers 1

export_lora_to_trt · Function · 0.95

Calls 1

load · Method · 0.80

Tested by

no test coverage detected