(self, destination=None, prefix="", keep_vars=False)
| 121 | return result |
| 122 | |
def state_dict(self, destination=None, prefix="", keep_vars=False):
    """Return a state dict laid out like a regular (unsplit) LoRA module.

    ``lora_A`` / ``lora_B`` are ``nn.ParameterList``s, so the default
    ``nn.Module.state_dict`` emits per-chunk keys such as ``<prefix>lora_A.0``.
    This override concatenates the chunks (the same way ``forward`` does,
    per the original author's note) and replaces those per-chunk keys with
    the standard ``<prefix>lora_down.weight`` / ``<prefix>lora_up.weight``
    entries.

    Args:
        destination: Optional mapping to fill; forwarded unchanged to
            ``nn.Module.state_dict``.
        prefix: Key prefix assigned to this module by its parent. When the
            LoRA network registers this module under ``lora_name`` this is
            ``"<lora_name>."``, but it may carry extra leading components when
            the network itself is nested inside another module.
        keep_vars: If False (the default), the concatenated tensors are
            detached, matching ``nn.Module.state_dict``'s behaviour.

    Returns:
        The state dict (``destination`` if one was given) with the per-chunk
        ``lora_A.N`` / ``lora_B.N`` entries replaced by ``lora_down.weight`` /
        ``lora_up.weight``.
    """
    sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

    # Chunks of A are stacked along dim 0 and chunks of B along dim 1.
    lora_down = torch.cat(tuple(self.lora_A), dim=0)
    lora_up = torch.cat(tuple(self.lora_B), dim=1)

    # For conv layers that are not 3x3 (presumably 1x1), add two trailing
    # singleton dims so the saved tensors have a conv-weight layout.
    if self.is_conv2d and not self.is_conv2d_3x3:
        lora_down = lora_down.unsqueeze(-1).unsqueeze(-1)
        lora_up = lora_up.unsqueeze(-1).unsqueeze(-1)

    if not keep_vars:
        lora_down = lora_down.detach()
        lora_up = lora_up.detach()

    # Remove the per-chunk ParameterList entries emitted above. Key them off
    # the `prefix` we were actually given rather than assuming it equals
    # f"{self.lora_name}.", so this also works when the module is nested
    # under a differently-prefixed parent or state_dict() is called directly.
    for i in range(len(self.lora_A)):
        sd.pop(f"{prefix}lora_A.{i}", None)
    for i in range(len(self.lora_B)):
        sd.pop(f"{prefix}lora_B.{i}", None)

    sd[prefix + "lora_down.weight"] = lora_down
    sd[prefix + "lora_up.weight"] = lora_up
    return sd
| 150 | |
| 151 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): |
| 152 | # 通常のLoRAと同じstate dictを読み込めるようにする:この方法はchatGPTに聞いた |
no outgoing calls