MCPcopy
hub / github.com/mudler/LocalAI / load_lora_weights

Method load_lora_weights

backend/python/diffusers/backend.py:664–716  ·  view source on GitHub ↗
(self, checkpoint_path, multiplier, device, dtype)

Source from the content-addressed store, hash-verified

662
# https://github.com/huggingface/diffusers/issues/3064
def load_lora_weights(self, checkpoint_path, multiplier, device, dtype):
    """Merge LoRA weights from a ``.safetensors`` file into ``self.pipe`` in place.

    For every LoRA layer in the file, the delta
    ``multiplier * scale * (lora_up @ lora_down)`` is added directly to the
    ``weight`` of the matching module. Layer names containing ``"text"``
    target ``self.pipe.text_encoder``; all others target ``self.pipe.unet``.
    Weights are modified in place, so calling this twice applies the LoRA
    twice.

    Args:
        checkpoint_path: Path of the LoRA ``.safetensors`` file.
        multiplier: Strength factor applied to every LoRA delta.
        device: Device passed to ``load_file`` for the loaded tensors.
        dtype: dtype the ``lora_up``/``lora_down`` weights are cast to
            before the delta is computed.

    Raises:
        ValueError: If a LoRA layer name cannot be resolved to a submodule
            of the pipeline.
        KeyError: If a layer lacks its ``lora_up.weight`` or
            ``lora_down.weight`` entry.
    """
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    state_dict = load_file(checkpoint_path, device=device)

    # Group tensors by layer. Keys usually look like
    # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight":
    # the part before the first "." names the layer, and the rest names the
    # element (e.g. "lora_up.weight", "lora_down.weight", "alpha").
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        layer, elem = key.split(".", 1)
        updates[layer][elem] = value

    for layer, elems in updates.items():
        if "text" in layer:
            prefix, curr_layer = LORA_PREFIX_TEXT_ENCODER, self.pipe.text_encoder
        else:
            prefix, curr_layer = LORA_PREFIX_UNET, self.pipe.unet
        layer_infos = layer.split(prefix + "_")[-1].split("_")

        # Walk down to the target module. Module names may themselves contain
        # "_" ("self_attn", "k_proj"), so a part that is not a submodule is
        # joined with the next part and retried.
        # nn.Module.__getattr__ is called directly rather than getattr():
        # - calling it directly restricts matches to registered submodules,
        #   parameters and buffers;
        # - plain getattr would also find methods such as Module.to when
        #   resolving the "to" part of "to_k".
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
            except AttributeError as err:
                if not layer_infos:
                    raise ValueError(
                        f"Cannot resolve LoRA layer {layer!r}: "
                        f"no submodule named {temp_name!r}"
                    ) from err
                next_part = layer_infos.pop(0)
                temp_name = f"{temp_name}_{next_part}" if temp_name else next_part
                continue
            if not layer_infos:
                break
            temp_name = layer_infos.pop(0)

        weight_up = elems["lora_up.weight"].to(dtype)
        weight_down = elems["lora_down.weight"].to(dtype)

        # scale = alpha / rank, where the rank is dim 1 of lora_up. A missing
        # or zero alpha leaves the delta unscaled (scale 1.0).
        alpha = elems.get("alpha")
        if alpha is not None and alpha.item() != 0:
            scale = alpha.item() / weight_up.shape[1]
        else:
            scale = 1.0

        if weight_up.dim() == 4:
            # Conv LoRA. Flattening both weights to matrices gives the same
            # result as squeezing the trailing 1x1 dims, and also supports a
            # larger kernel on lora_down.
            # Presumably lora_up is (out, rank, 1, 1) and lora_down is
            # (rank, in, kh, kw); any other layout fails inside torch.mm.
            delta = torch.mm(weight_up.flatten(1), weight_down.flatten(1))
            delta = delta.reshape(weight_up.shape[0], *weight_down.shape[1:])
        else:
            delta = torch.mm(weight_up, weight_down)

        curr_layer.weight.data += multiplier * scale * delta
717
718 def GenerateImage(self, request, context):
719

Callers 1

LoadModel · Method · 0.80

Calls 4

pop · Method · 0.80
items · Method · 0.45
__getattr__ · Method · 0.45
to · Method · 0.45

Tested by

no test coverage detected