Repository: github.com/modelscope/DiffSynth-Studio

Method load

diffsynth/lora/__init__.py:28–45
(self, model: torch.nn.Module, state_dict_lora, alpha=1.0)
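Read against the method body, the update `load` applies to each matched module can be written as follows. Here U is the `weight_up` tensor, D is the `weight_down` tensor and W is the module's existing `weight`. This restates the listed code and is not taken from any DiffSynth-Studio documentation. Note that the only scaling inside this method is the `alpha` argument; no division by the rank r appears here.

```latex
W' = W + \alpha\, U D,
\qquad U \in \mathbb{R}^{d_{\text{out}} \times r},\quad
D \in \mathbb{R}^{r \times d_{\text{in}}},\quad
W, W' \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}

% 1x1 convolution weights of shape (d_out, d_in, 1, 1): the same update on the single spatial position
W'_{:,:,0,0} = W_{:,:,0,0} + \alpha\, U_{:,:,0,0}\, D_{:,:,0,0}
```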

```python
import torch  # imported at module level in the source file

def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
    updated_num = 0
    # Map each target module name to its (lora_up key, lora_down key) pair.
    lora_name_dict = self.get_name_dict(state_dict_lora)
    for name, module in model.named_modules():
        if name in lora_name_dict:
            weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
            weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
            if len(weight_up.shape) == 4:
                # 1x1 convolution: drop the trailing (1, 1) dims, multiply, then restore them.
                weight_up = weight_up.squeeze(3).squeeze(2)
                weight_down = weight_down.squeeze(3).squeeze(2)
                weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                # Linear layer: the low-rank update is up @ down, scaled by alpha.
                weight_lora = alpha * torch.mm(weight_up, weight_down)
            # Add the update to the existing weight and write it back into the module.
            state_dict = module.state_dict()
            state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
            module.load_state_dict(state_dict)
            updated_num += 1
    print(f"{updated_num} tensors are updated by LoRA.")
```
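The arithmetic in the listing can be checked without PyTorch. The sketch below uses plain Python lists in place of tensors; `matmul`, `merge_lora`, `squeeze_1x1` and `unsqueeze_1x1` are hypothetical helpers written for illustration, not DiffSynth-Studio functions. It mirrors the two branches of the method: a 2D linear weight gets `alpha * (up @ down)` added directly, while a 1x1 convolution weight of shape (out, in, 1, 1) is handled by squeezing to 2D and restoring the trailing dimensions afterwards.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-list stand-in for torch.mm: (n x k) @ (k x m) -> (n x m)."""
    k = len(b)
    if any(len(row) != k for row in a):
        raise ValueError("inner dimensions do not match")
    m = len(b[0])
    return [[sum(row[t] * b[t][j] for t in range(k)) for j in range(m)] for row in a]


def merge_lora(weight: Matrix, up: Matrix, down: Matrix, alpha: float = 1.0) -> Matrix:
    """Return weight + alpha * (up @ down), the update of the linear branch."""
    delta = matmul(up, down)
    return [
        [w + alpha * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


def squeeze_1x1(w4):
    """(out, in, 1, 1) -> (out, in), like .squeeze(3).squeeze(2)."""
    return [[cell[0][0] for cell in row] for row in w4]


def unsqueeze_1x1(w2: Matrix):
    """(out, in) -> (out, in, 1, 1), like .unsqueeze(2).unsqueeze(3)."""
    return [[[[v]] for v in row] for row in w2]


# Linear case: weight is 2 x 3 (out x in), rank r = 1.
weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
up = [[1.0], [2.0]]        # shape (out=2, r=1)
down = [[1.0, 1.0, 1.0]]   # shape (r=1, in=3)
merged = merge_lora(weight, up, down, alpha=0.5)
print(merged)  # [[1.5, 0.5, 0.5], [1.0, 2.0, 1.0]]

# 1x1 conv case: the same numbers stored with trailing (1, 1) dims.
conv_weight = unsqueeze_1x1(weight)
conv_up = unsqueeze_1x1(up)
conv_down = unsqueeze_1x1(down)
conv_merged = unsqueeze_1x1(
    merge_lora(squeeze_1x1(conv_weight), squeeze_1x1(conv_up), squeeze_1x1(conv_down), alpha=0.5)
)
print(squeeze_1x1(conv_merged) == merged)  # True
```

Squeezing the base weight as well, as done here for brevity, gives the same numbers as adding the unsqueezed update to the 4D weight the way the method does, because both trailing dimensions have size 1.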

Callers 2

load_lora (method) · 0.95
load_lora (method) · 0.95

Calls 2

get_name_dict (method) · 0.95
to (method) · 0.45
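`load` relies on `get_name_dict` returning a mapping from a module name, as yielded by `model.named_modules()`, to a pair whose element 0 is the key of the up-projection tensor and element 1 the key of the down-projection tensor; that much follows from how the method indexes it. The implementation of `get_name_dict` is not shown on this page, so the sketch below is hypothetical: `build_name_dict` is an illustrative name, and the `.lora_up.weight` / `.lora_down.weight` key suffixes are assumptions, not DiffSynth-Studio's actual key format.

```python
from typing import Dict, Iterable, Tuple

# Assumed key suffixes (hypothetical; the real LoRA key naming may differ).
UP_SUFFIX = ".lora_up.weight"
DOWN_SUFFIX = ".lora_down.weight"


def build_name_dict(keys: Iterable[str]) -> Dict[str, Tuple[str, str]]:
    """Hypothetical stand-in for get_name_dict.

    Maps a module name to (up_key, down_key); load() reads index 0 as
    weight_up and index 1 as weight_down.
    """
    key_set = set(keys)
    name_dict: Dict[str, Tuple[str, str]] = {}
    for key in sorted(key_set):
        if not key.endswith(UP_SUFFIX):
            continue
        module_name = key[: -len(UP_SUFFIX)]
        down_key = module_name + DOWN_SUFFIX
        if down_key in key_set:  # skip modules missing either half of the pair
            name_dict[module_name] = (key, down_key)
    return name_dict


lora_keys = [
    "blocks.0.attn.to_q.lora_up.weight",
    "blocks.0.attn.to_q.lora_down.weight",
    "blocks.0.attn.to_k.lora_up.weight",  # no matching down tensor
]
name_dict = build_name_dict(lora_keys)
print(name_dict)  # only blocks.0.attn.to_q is mapped
```

Because `load` matches these names against `model.named_modules()` with a plain `in` test, any prefix in the LoRA keys that differs from the model's own module names would presumably have to be rewritten before this lookup; otherwise no module matches and the method reports 0 updated tensors.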

Tested by

no test coverage detected