MCPcopy
hub / github.com/huggingface/diffusers / get_peft_kwargs

Function get_peft_kwargs

src/diffusers/utils/peft_utils.py:153–200  ·  view source on GitHub ↗
(
    rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
)

Source from the content-addressed store, hash-verified

151
152
def get_peft_kwargs(
    rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
    """Build the keyword arguments for a LoRA config from a LoRA state dict.

    Args:
        rank_dict: Mapping of state-dict key to rank. The most frequent rank
            becomes ``r``; entries with any other rank go into
            ``rank_pattern``, keyed by the part of the key before
            ``".lora_B."``.
        network_alpha_dict: Optional mapping of key to alpha. ``None`` or an
            empty mapping leaves ``lora_alpha`` at its fallback value.
        peft_state_dict: Mapping whose keys are scanned to derive
            ``target_modules``, ``use_dora`` and ``lora_bias``.
        is_unet: Selects how alpha keys are turned into module names for
            ``alpha_pattern``.
        model_state_dict: Not used by this function.
        adapter_name: Not used by this function.

    Returns:
        A dict with the keys ``r``, ``lora_alpha``, ``rank_pattern``,
        ``alpha_pattern``, ``target_modules``, ``use_dora`` and ``lora_bias``.

    Raises:
        IndexError: If ``rank_dict`` is empty.
    """
    rank_pattern = {}
    alpha_pattern = {}
    # Fallback for both values is the FIRST rank in `rank_dict`; note that `lora_alpha` keeps this
    # fallback (not the most common rank) when no alphas are supplied.
    r = lora_alpha = list(rank_dict.values())[0]

    if len(set(rank_dict.values())) > 1:
        # get the rank occurring the most number of times
        r = collections.Counter(rank_dict.values()).most_common()[0][0]

        # for modules with rank different from the most occurring rank, add it to the `rank_pattern`
        rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_dict.items() if v != r}

    if network_alpha_dict is not None and len(network_alpha_dict) > 0:
        if len(set(network_alpha_dict.values())) > 1:
            # get the alpha occurring the most number of times
            lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]

            # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
            differing_alphas = {k: v for k, v in network_alpha_dict.items() if v != lora_alpha}
            if is_unet:
                # keep everything before ".lora_A." and strip the ".alpha" marker
                alpha_pattern = {
                    k.split(".lora_A.")[0].replace(".alpha", ""): v for k, v in differing_alphas.items()
                }
            else:
                # keep everything before ".down." and drop the last dotted component
                alpha_pattern = {
                    ".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in differing_alphas.items()
                }
        else:
            # every entry shares one alpha value
            lora_alpha = set(network_alpha_dict.values()).pop()

    # de-duplicated module names; built from a set, so the list order is not guaranteed
    target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
    use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
    # for now we know that the "bias" keys are only associated with `lora_B`.
    lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)

    lora_config_kwargs = {
        "r": r,
        "lora_alpha": lora_alpha,
        "rank_pattern": rank_pattern,
        "alpha_pattern": alpha_pattern,
        "target_modules": target_modules,
        "use_dora": use_dora,
        "lora_bias": lora_bias,
    }

    return lora_config_kwargs
201
202
203def get_adapter_name(model):

Callers 2

_create_lora_configFunction · 0.85
_process_loraMethod · 0.85

Calls 2

splitMethod · 0.80
popMethod · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…