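Eval: switch the layer to evaluation mode and, when `merge_weights` is set, merge the LoRA update into `weight`.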
(self)
| 340 | self.merged = False |
| 341 | |
def eval(self):
    """Switch to evaluation mode and merge the LoRA update into ``weight``.

    First calls ``nn.Linear.eval``. Then, when ``merge_weights`` is set and the
    update has not been merged yet, it adds
    ``zero_pad(T(delta_w * scaling))`` to ``self.weight`` in place and sets
    ``merged`` to True. When ``r == 0`` or no entry of ``enable_lora`` is
    true, nothing is added, but ``merged`` is still set.

    Returns:
        The module itself, matching ``torch.nn.Module.eval``, so chained calls
        such as ``layer = layer.eval()`` keep working.

    Note:
        ``torch.nn.Module.eval`` on a *parent* module reaches its children
        through ``child.train(False)``, not ``child.eval()``. This override
        therefore only runs when ``eval()`` is called on this layer directly.
        TODO confirm that the layer's ``train`` handles merging for that path.
    """
    def T(w):
        """Return ``w.T`` if ``self.fan_in_fan_out`` is set, else ``w`` unchanged.

        Args:
            w: 2-D tensor holding the LoRA weight update.
        """
        return w.T if self.fan_in_fan_out else w

    # Module.eval() delegates to self.train(False). Calling it through
    # nn.Linear means any train() override on this class runs before the merge.
    nn.Linear.eval(self)
    if self.merge_weights and not self.merged:
        # Merge the weights and mark them as merged, so repeated eval() calls
        # do not add the update a second time.
        if self.r > 0 and any(self.enable_lora):
            # The kernel-size-1 grouped conv1d computes one B_g @ A_g product
            # per group (one group per true entry of enable_lora). The results
            # are stacked along dim 0 after squeeze(0).
            delta_w = F.conv1d(
                self.lora_A.data.unsqueeze(0),
                self.lora_B.data.unsqueeze(-1),
                groups=sum(self.enable_lora),
            ).squeeze(0)
            # zero_pad is defined elsewhere on the class. Presumably it places
            # the stacked rows into the full weight shape (TODO confirm).
            self.weight.data += self.zero_pad(T(delta_w * self.scaling))
        self.merged = True
    return self
| 364 | def forward(self, x: torch.Tensor): |
| 365 | """Forward pass for training. |