MCPcopy
hub / github.com/kohya-ss/sd-scripts / state_dict

Method state_dict

networks/dylora.py:123–149  ·  view source on GitHub ↗
(self, destination=None, prefix="", keep_vars=False)

Source from the content-addressed store, hash-verified

121 return result
122
def state_dict(self, destination=None, prefix="", keep_vars=False):
    """Return a state dict laid out like a regular LoRA module's.

    The entries of ``self.lora_A`` / ``self.lora_B`` are serialized by the
    parent class under indexed names such as ``.lora_A.0``. This override
    concatenates them into single ``lora_down.weight`` / ``lora_up.weight``
    tensors and removes the indexed entries.

    Args:
        destination: Forwarded unchanged to the parent ``state_dict``.
        prefix: Forwarded unchanged to the parent ``state_dict``.
        keep_vars: Forwarded to the parent ``state_dict``. When false, the
            merged tensors are detached before being stored.

    Returns:
        The mapping produced by the parent ``state_dict``, with the merged
        keys added and the indexed ``lora_A`` / ``lora_B`` keys removed.
    """
    merged = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

    # Concatenate the per-index pieces: down weights along dim 0, up weights along dim 1.
    down_weight = torch.cat(tuple(self.lora_A), dim=0)
    up_weight = torch.cat(tuple(self.lora_B), dim=1)

    # Conv2d modules that are not 3x3 get two trailing singleton dimensions.
    if self.is_conv2d and not self.is_conv2d_3x3:
        down_weight = down_weight.unsqueeze(-1).unsqueeze(-1)
        up_weight = up_weight.unsqueeze(-1).unsqueeze(-1)

    if not keep_vars:
        down_weight = down_weight.detach()
        up_weight = up_weight.detach()

    # NOTE(review): keys are built from self.lora_name, not from `prefix`;
    # presumably callers rely on prefix == lora_name + "." -- confirm.
    merged[self.lora_name + ".lora_down.weight"] = down_weight
    merged[self.lora_name + ".lora_up.weight"] = up_weight

    # Remove the indexed entries, stopping at the first missing lora_A index.
    index = 0
    while f"{self.lora_name}.lora_A.{index}" in merged:
        merged.pop(f"{self.lora_name}.lora_A.{index}")
        merged.pop(f"{self.lora_name}.lora_B.{index}")
        index += 1
    return merged
150
151 def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
152 # 通常のLoRAと同じstate dictを読み込めるようにする:この方法はchatGPTに聞いた

Callers 15

mainFunction · 0.45
save_modelFunction · 0.45
mainFunction · 0.45
trainFunction · 0.45
save_modelFunction · 0.45
trainFunction · 0.45
mainFunction · 0.45
load_dit_modelFunction · 0.45
merge_lora_weightsFunction · 0.45
load_control_netFunction · 0.45
mainFunction · 0.45
merge_toMethod · 0.45

Calls

no outgoing calls

Tested by 2

mainFunction · 0.36