(self, checkpoint_path, multiplier, device, dtype)
| 662 | |
# https://github.com/huggingface/diffusers/issues/3064
def load_lora_weights(self, checkpoint_path, multiplier, device, dtype):
    """Merge LoRA weights from a ``.safetensors`` checkpoint into the pipeline in place.

    Checkpoint keys look like
    ``lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight``:
    the part before the first ``.`` names a layer (prefixed ``lora_te_`` for
    ``self.pipe.text_encoder``; anything else has a ``lora_unet_`` prefix
    stripped and is looked up in ``self.pipe.unet``) and the rest names one of
    that layer's LoRA tensors. For every layer the matching module's weight is
    updated as::

        weight += multiplier * scale * (lora_up @ lora_down)

    with ``scale = alpha / rank`` when a non-zero ``alpha`` is stored for the
    layer (``rank`` being ``lora_up.weight.shape[1]``) and ``1.0`` otherwise.
    For 4-D (conv) weights the product is reshaped to
    ``(out, in, kh, kw)`` using ``lora_down``'s kernel dimensions.

    Args:
        checkpoint_path: Path of the ``.safetensors`` checkpoint.
        multiplier: Strength applied to every LoRA delta.
        device: Device passed to ``load_file`` for the checkpoint tensors.
        dtype: dtype ``lora_up``/``lora_down`` are cast to before multiplying.

    Raises:
        ValueError: If a layer name cannot be resolved to a sub-module.
    """
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def _resolve_module(root, parts, layer):
        """Return the sub-module of ``root`` addressed by the "_"-split ``parts``."""
        # Module names can themselves contain "_" (e.g. "self_attn", "k_proj"),
        # so adjacent parts are glued together until a registered child matches.
        module = root
        pending = ""
        matched = False
        for part in parts:
            pending = f"{pending}_{part}" if pending else part
            try:
                # nn.Module.__getattr__ only looks at registered parameters,
                # buffers and sub-modules, so e.g. "to" (from "to_k") is not
                # resolved to the Module.to() method.
                module = module.__getattr__(pending)
            except AttributeError:
                matched = False
                continue
            pending = ""
            matched = True
        if not matched:
            raise ValueError(
                f"Could not resolve LoRA layer {layer!r}: no sub-module named {pending!r}"
            )
        return module

    # load LoRA weight from .safetensors
    state_dict = load_file(checkpoint_path, device=device)

    # Group tensors by layer: "<layer>.<elem>" -> updates[layer][elem].
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        layer, elem = key.split(".", 1)
        updates[layer][elem] = value

    # directly update weight in diffusers model
    for layer, elems in updates.items():
        # Dispatch on the key prefix rather than on the substring "text",
        # which is not specific to text-encoder keys.
        if layer.startswith(LORA_PREFIX_TEXT_ENCODER + "_"):
            prefix, root = LORA_PREFIX_TEXT_ENCODER, self.pipe.text_encoder
        else:
            prefix, root = LORA_PREFIX_UNET, self.pipe.unet
        parts = layer.split(prefix + "_")[-1].split("_")
        target = _resolve_module(root, parts, layer)

        weight_up = elems["lora_up.weight"].to(dtype)
        weight_down = elems["lora_down.weight"].to(dtype)
        alpha = elems.get("alpha")
        # A missing or zero alpha means no extra scaling.
        scale = alpha.item() / weight_up.shape[1] if alpha else 1.0

        # flatten(1) turns conv kernels (out, r, 1, 1) / (r, in, kh, kw) into
        # matrices; it is a no-op for linear (2-D) weights.
        delta = torch.mm(weight_up.flatten(1), weight_down.flatten(1))
        if weight_up.dim() == 4:
            delta = delta.reshape(weight_up.shape[0], *weight_down.shape[1:])
        target.weight.data += multiplier * scale * delta
| 717 | |
| 718 | def GenerateImage(self, request, context): |
| 719 |
no test coverage detected