(self, model: torch.nn.Module, state_dict_lora, alpha=1.0)
| 26 | |
| 27 | |
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
    """Merge LoRA weights from ``state_dict_lora`` into ``model`` in place.

    For every sub-module of ``model`` whose qualified name appears in the
    mapping returned by ``self.get_name_dict(state_dict_lora)``, the
    low-rank product ``alpha * (up @ down)`` is added to the module's
    ``"weight"`` entry and written back through ``load_state_dict``.

    Args:
        model: Module whose ``named_modules()`` are scanned for matches.
        state_dict_lora: Mapping from tensor names to LoRA tensors. It is
            indexed with the key pairs produced by ``self.get_name_dict``.
        alpha: Scalar multiplier applied to the low-rank product.

    Returns:
        None. The number of updated modules is printed.

    Notes:
        4-D (convolution-style) factors are flattened to matrices before
        the product, and the result is reshaped to
        ``(up.shape[0], *down.shape[1:])``. This matches the previous
        behaviour for 1x1 kernels and additionally supports a ``k x k``
        down-projection paired with a 1x1 up-projection.
        NOTE(review): assumes the up factor of a conv LoRA is always 1x1 —
        confirm against the checkpoints this loader is used with.
    """
    updated_num = 0
    # Maps module name -> (key of the "up" tensor, key of the "down" tensor).
    lora_name_dict = self.get_name_dict(state_dict_lora)
    for name, module in model.named_modules():
        if name not in lora_name_dict:
            continue
        up_key, down_key = lora_name_dict[name][0], lora_name_dict[name][1]
        weight_up = state_dict_lora[up_key].to(device=self.device, dtype=self.torch_dtype)
        weight_down = state_dict_lora[down_key].to(device=self.device, dtype=self.torch_dtype)
        if len(weight_up.shape) == 4:
            # Collapse everything after the first axis so torch.mm sees two
            # matrices, then restore the trailing (in, kh, kw) axes taken
            # from the down factor.
            delta = torch.mm(weight_up.flatten(start_dim=1), weight_down.flatten(start_dim=1))
            weight_lora = alpha * delta.reshape(weight_up.shape[0], *weight_down.shape[1:])
        else:
            weight_lora = alpha * torch.mm(weight_up, weight_down)
        # Round-trip through the state dict so load_state_dict copies the
        # merged tensor back into the module's existing parameter.
        state_dict = module.state_dict()
        state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
        module.load_state_dict(state_dict)
        updated_num += 1
    print(f"{updated_num} tensors are updated by LoRA.")
no test coverage detected