(
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
)
| 151 | |
| 152 | |
def get_peft_kwargs(
    rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
    """Build the keyword arguments for a PEFT ``LoraConfig`` from LoRA rank/alpha metadata.

    Args:
        rank_dict: Mapping of ``lora_B`` parameter names to their LoRA rank. The most frequent rank becomes
            ``r``; every entry with a different rank is recorded in ``rank_pattern`` under its module
            prefix (the text before ``.lora_B.``).
        network_alpha_dict: Optional mapping of parameter names to alpha values. When non-empty, its most
            frequent value becomes ``lora_alpha`` and every differing entry is recorded in
            ``alpha_pattern``. When ``None`` or empty, ``lora_alpha`` defaults to ``r``.
        peft_state_dict: PEFT-format state dict; only its keys are inspected.
        is_unet: Chooses how ``alpha_pattern`` keys are mapped to module names. For the UNet the key is
            the text before ``.lora_A.`` with ``.alpha`` removed. Otherwise it is the text before
            ``.down.`` minus its last dotted component.
        model_state_dict: Not used by this function; kept for signature compatibility.
        adapter_name: Not used by this function; kept for signature compatibility.

    Returns:
        dict: ``r``, ``lora_alpha``, ``rank_pattern``, ``alpha_pattern``, ``target_modules`` (a sorted
        list), ``use_dora`` and ``lora_bias``.

    Raises:
        ValueError: If ``rank_dict`` is empty.
    """
    if not rank_dict:
        raise ValueError("`rank_dict` is empty; cannot infer a LoRA rank.")

    # Most frequent rank (ties go to the rank encountered first). With one distinct rank this is that rank.
    r = collections.Counter(rank_dict.values()).most_common()[0][0]
    # Without explicit alphas, default alpha to the dominant rank. This previously used the *first* value
    # of `rank_dict`, which made the resulting config depend on dict ordering whenever ranks were mixed.
    lora_alpha = r
    # Modules whose rank differs from the dominant one, keyed by module prefix.
    # NOTE(review): with no alphas given, these modules still inherit the global `lora_alpha`; confirm
    # whether they should also get an `alpha_pattern` entry equal to their own rank.
    rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_dict.items() if v != r}

    alpha_pattern = {}
    if network_alpha_dict:
        # Most frequent alpha; with a single distinct alpha this is simply that alpha.
        lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
        differing = {k: v for k, v in network_alpha_dict.items() if v != lora_alpha}
        if is_unet:
            alpha_pattern = {k.split(".lora_A.")[0].replace(".alpha", ""): v for k, v in differing.items()}
        else:
            alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in differing.items()}

    # Sorted so the config is deterministic across runs (set order of str keys depends on the hash seed).
    target_modules = sorted({name.split(".lora")[0] for name in peft_state_dict})
    use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
    # for now we know that the "bias" keys are only associated with `lora_B`.
    lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)

    return {
        "r": r,
        "lora_alpha": lora_alpha,
        "rank_pattern": rank_pattern,
        "alpha_pattern": alpha_pattern,
        "target_modules": target_modules,
        "use_dora": use_dora,
        "lora_bias": lora_bias,
    }
| 201 | |
| 202 | |
| 203 | def get_adapter_name(model): |
no test coverage detected
searching dependent graphs…