(model_dir, dtype, out_dir)
| 130 | |
| 131 | |
| 132 | def convert_hf_model(model_dir, dtype, out_dir): |
| 133 | saved_dir = Path(out_dir) |
| 134 | saved_dir.mkdir(parents=True, exist_ok=True) |
| 135 | with open(f"{model_dir}/adapter_config.json", "r") as f: |
| 136 | config = json.load(f) |
| 137 | |
| 138 | alpha = config.get("lora_alpha") |
| 139 | use_rslora = config.get("use_rslora", False) |
| 140 | |
| 141 | lora_model = load_state_dict(get_model_path(model_dir, "adapter_model")) |
| 142 | lora_model = preprocess_lora_weights(lora_model) |
| 143 | all_weights = get_all_lora_weights(lora_model) |
| 144 | converted_weights = [] |
| 145 | converted_config = [] |
| 146 | |
def derive_adapter_size(inout_weight: torch.Tensor) -> int:
    """Return the adapter size (LoRA rank) of a 2-D LoRA weight matrix.

    The matrix may be in either orientation. The rank is taken to be the
    smaller of its two dimensions, on the assumption that the hidden
    dimension is always the larger one.

    Raises:
        AssertionError: If ``inout_weight`` is not 2-D (only while Python
            assertions are enabled).
    """
    shape = tuple(inout_weight.shape)
    assert len(shape) == 2
    # Hidden size is assumed to dominate, so the rank is the shorter side.
    return min(shape)
| 153 | |
def derive_weights_scale(adapter_size: int, alpha: float,
                         use_rslora: bool) -> float:
    """Return the scaling factor applied to a LoRA adapter's output weights.

    Standard LoRA uses ``alpha / r``; with ``use_rslora`` the rank-stabilized
    form ``alpha / sqrt(r)`` is used instead, where ``r`` is ``adapter_size``.

    Args:
        adapter_size: The adapter size (LoRA rank ``r``); must be positive.
        alpha: The ``lora_alpha`` value read from the adapter config.
        use_rslora: If true, divide by ``sqrt(r)`` rather than ``r``.

    Returns:
        The scale as a built-in ``float`` for both formulas.

    Raises:
        ValueError: If ``adapter_size`` is not positive.
    """
    # math.sqrt (not np.sqrt) keeps the result a plain float on both paths.
    import math

    # A non-positive rank is never valid. Fail loudly and identically for both
    # formulas instead of ZeroDivisionError on one path and a silent
    # inf/nan on the other.
    if adapter_size <= 0:
        raise ValueError(
            f"adapter_size must be positive, got {adapter_size}")

    denominator = math.sqrt(adapter_size) if use_rslora else adapter_size
    return float(alpha / denominator)
| 159 | |
| 160 | for layer_idx, layer_weights in all_weights.items(): |
| 161 | for hf_module, module_weights in layer_weights.items(): |
| 162 | in_weights = module_weights['in'] |
| 163 | out_weights = module_weights['out'] |
| 164 | magnitude = module_weights.get("magnitude", None) |
| 165 | is_dora = magnitude is not None |
| 166 | |
| 167 | processed_weights = [] |
| 168 | |
| 169 | assert len(in_weights.shape) == 2 |
| 170 | assert len(out_weights.shape) == 2 |
| 171 | assert not is_dora or len(magnitude.shape) == 1 |
| 172 | |
| 173 | adapter_size = derive_adapter_size(in_weights) |
| 174 | assert adapter_size == derive_adapter_size( |
| 175 | out_weights), "adapter size of A mismatches adapter size of B" |
| 176 | scale = derive_weights_scale(adapter_size, alpha, use_rslora) |
| 177 | |
| 178 | for w, inout in ((in_weights, "in"), (out_weights, "out")): |
| 179 | dim0 = w.shape[0] |
| 180 | dim1 = w.shape[1] |
| 181 | # in_weights should have shape [adaper_size, hidden] |
| 182 | if dim1 < dim0 and inout == "in": |
| 183 | w = w.transpose(1, 0) |
| 184 | # out_weights should have shape [hidden, adapter_size] |
| 185 | elif dim0 < dim1 and inout == "out": |
| 186 | w = w.transpose(1, 0) |
| 187 | if inout == "out": |
| 188 | w = w * scale |
| 189 | w = w.contiguous().flatten().to(dtype=str_dtype_to_torch(dtype)) |
no test coverage detected