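Set the training mode via nn.Linear.train and, if merge_weights is set and the LoRA update is currently merged into the weight, subtract that update again and clear the merged flag. Args: mode: flag forwarded unchanged to nn.Linear.train (defaults to True).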
(self, mode: bool = True)
| 314 | return result.view((*x.shape[:-1], self.out_features)) |
| 315 | |
def train(self, mode: bool = True):
    """Set the training mode and undo any merged LoRA update.

    Delegates to ``nn.Linear.train`` and then, when ``merge_weights`` is set
    and the layer is flagged as ``merged``, subtracts the scaled LoRA delta
    from ``self.weight`` and clears the ``merged`` flag.

    NOTE(review): the un-merge step does not inspect ``mode``, so
    ``train(False)`` un-merges as well -- confirm this is the intended
    interaction with ``eval``.

    Args:
        mode: Forwarded unchanged to ``nn.Linear.train``.

    Returns:
        This module, matching the ``nn.Module.train`` contract so that
        calls such as ``layer.train().to(device)`` keep working.
    """
    def T(w):
        """Return ``w`` transposed when ``fan_in_fan_out`` is set, else ``w``."""
        return w.T if self.fan_in_fan_out else w

    nn.Linear.train(self, mode)
    if self.merge_weights and self.merged:
        # Make sure that the weights are not merged
        if self.r > 0 and any(self.enable_lora):
            # Grouped conv of lora_A with lora_B (kernel size 1); the group
            # count is the number of truthy entries in enable_lora.
            delta_w = F.conv1d(
                self.lora_A.data.unsqueeze(0),
                self.lora_B.data.unsqueeze(-1),
                groups=sum(self.enable_lora),
            ).squeeze(0)
            # Subtract exactly the scaled, zero-padded delta from the weight.
            self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
        # The flag is cleared even when there was no delta to subtract.
        self.merged = False
    return self
| 341 | |
| 342 | def eval(self): |
| 343 | """Eval.""" |