| 49 | |
| 50 | |
def get_refit_weights(
    state_dict: dict, onnx_opt_path: str, weight_name_mapping: dict, weight_shape_mapping: dict
) -> dict:
    """Collect weight deltas between a PyTorch state dict and an optimized ONNX model.

    For every ``torch_name -> initializer_name`` pair in ``weight_name_mapping``:
    1. The tensor ``state_dict[torch_name]`` is brought into the initializer's layout.
    2. It is cast to float16 and compared byte-for-byte with the matching ONNX
       initializer, which is also cast to float16.
    3. If they differ, ``weight - initializer`` is recorded for that initializer.

    Args:
        state_dict: Mapping from PyTorch parameter names to tensors.
        onnx_opt_path: Path to the optimized ONNX model. Its directory is passed
            as ``base_dir`` when decoding initializers, so that external tensor
            data can be resolved.
        weight_name_mapping: Mapping ``torch_name -> initializer_name``.
        weight_shape_mapping: Mapping ``torch_name -> (initializer_shape, is_transpose)``.
            If ``is_transpose`` is true, dims 0 and 1 of the weight are swapped
            and ``initializer_shape`` is not used. Otherwise the weight is
            reshaped to ``initializer_shape``.

    Returns:
        An ``OrderedDict`` keyed by initializer name. Each value is a contiguous
        delta tensor on the same device as the corresponding ``state_dict``
        entry. Only weights that differ from their initializer are included.

    Raises:
        KeyError: If a mapped name is missing from the ONNX initializers,
            ``state_dict``, or ``weight_shape_mapping``.
    """
    refit_weights = OrderedDict()
    onnx_opt_dir = os.path.dirname(onnx_opt_path)
    onnx_opt_model = onnx.load(onnx_opt_path)

    # Decode only the initializers that are actually mapped. Casting every
    # initializer (e.g. integer shape constants) to float16 wastes memory and
    # can overflow or fail for tensors that are never refitted.
    needed_initializers = set(weight_name_mapping.values())
    onnx_data_mapping = {}
    for initializer in onnx_opt_model.graph.initializer:
        if initializer.name in needed_initializers:
            # astype() copies, so the arrays do not reference the protobuf.
            onnx_data_mapping[initializer.name] = numpy_helper.to_array(
                initializer, base_dir=onnx_opt_dir
            ).astype(np.float16)
    # The model proto is not needed anymore; free it before the diff loop.
    del onnx_opt_model

    for torch_name, initializer_name in weight_name_mapping.items():
        initializer_data = onnx_data_mapping[initializer_name]
        wt = state_dict[torch_name]

        # Get shape transform info.
        initializer_shape, is_transpose = weight_shape_mapping[torch_name]
        if is_transpose:
            wt = torch.transpose(wt, 0, 1)
        else:
            wt = torch.reshape(wt, initializer_shape)

        # NumPy has no bfloat16, so upcast it (exactly) to float32 first.
        wt_cpu = wt.detach().cpu()
        if wt_cpu.dtype == torch.bfloat16:
            wt_cpu = wt_cpu.float()
        wt_fp16 = wt_cpu.numpy().astype(np.float16)

        # Include the weight if its float16 bytes differ from the initializer's.
        # Compare the bytes themselves rather than hash() values: a hash
        # collision would otherwise silently drop a changed weight.
        if wt_fp16.tobytes() != initializer_data.tobytes():
            delta = wt - torch.tensor(initializer_data).to(wt.device)
            refit_weights[initializer_name] = delta.contiguous()

    return refit_weights
| 86 | |
| 87 | |
| 88 | def export_lora( |