(self, onnx_opt_path: str, weights_map_path: dict)
| 149 | |
| 150 | # Helper utility for weights map |
| 151 | def export_weights_map(self, onnx_opt_path: str, weights_map_path: dict): |
| 152 | onnx_opt_dir = onnx_opt_path |
| 153 | state_dict = self.unet.state_dict() |
| 154 | onnx_opt_model = onnx.load(onnx_opt_path) |
| 155 | |
| 156 | # Create initializer data hashes |
| 157 | def init_hash_map(onnx_opt_model): |
| 158 | initializer_hash_mapping = {} |
| 159 | for initializer in onnx_opt_model.graph.initializer: |
| 160 | initializer_data = numpy_helper.to_array( |
| 161 | initializer, base_dir=onnx_opt_dir |
| 162 | ).astype(np.float16) |
| 163 | initializer_hash = hash(initializer_data.data.tobytes()) |
| 164 | initializer_hash_mapping[initializer.name] = ( |
| 165 | initializer_hash, |
| 166 | initializer_data.shape, |
| 167 | ) |
| 168 | return initializer_hash_mapping |
| 169 | |
| 170 | initializer_hash_mapping = init_hash_map(onnx_opt_model) |
| 171 | |
| 172 | weights_name_mapping = {} |
| 173 | weights_shape_mapping = {} |
| 174 | # set to keep track of initializers already added to the name_mapping dict |
| 175 | initializers_mapped = set() |
| 176 | for wt_name, wt in state_dict.items(): |
| 177 | # get weight hash |
| 178 | wt = wt.cpu().detach().numpy().astype(np.float16) |
| 179 | wt_hash = hash(wt.data.tobytes()) |
| 180 | wt_t_hash = hash(np.transpose(wt).data.tobytes()) |
| 181 | |
| 182 | for initializer_name, ( |
| 183 | initializer_hash, |
| 184 | initializer_shape, |
| 185 | ) in initializer_hash_mapping.items(): |
| 186 | # Due to constant folding, some weights are transposed during export |
| 187 | # To account for the transpose op, we compare the initializer hash to the |
| 188 | # hash for the weight and its transpose |
| 189 | if wt_hash == initializer_hash or wt_t_hash == initializer_hash: |
| 190 | # The assert below ensures there is a 1:1 mapping between |
| 191 | # PyTorch and ONNX weight names. It can be removed in cases where 1:many |
| 192 | # mapping is found and name_mapping[wt_name] = list() |
| 193 | assert initializer_name not in initializers_mapped |
| 194 | weights_name_mapping[wt_name] = initializer_name |
| 195 | initializers_mapped.add(initializer_name) |
| 196 | is_transpose = False if wt_hash == initializer_hash else True |
| 197 | weights_shape_mapping[wt_name] = ( |
| 198 | initializer_shape, |
| 199 | is_transpose, |
| 200 | ) |
| 201 | |
| 202 | # Sanity check: Were any weights not matched |
| 203 | if wt_name not in weights_name_mapping: |
| 204 | print( |
| 205 | f"[I] PyTorch weight {wt_name} not matched with any ONNX initializer" |
| 206 | ) |
| 207 | print( |
| 208 | f"[I] UNet: {len(weights_name_mapping.keys())} PyTorch weights were matched with ONNX initializers" |
no test coverage detected