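Forward pass. Applies the base linear projection to x and, when the LoRA weights are not merged and r > 0, adds the scaled low-rank update. Args: x: input tensor.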
(self, x: torch.Tensor)
| 362 | self.merged = True |
| 363 | |
def forward(self, x: torch.Tensor):
    """Apply the linear layer, adding the low-rank update when it is unmerged.

    The base projection ``F.linear(x, W, bias=self.bias)`` is always
    computed, where ``W`` is ``self.weight`` transposed first when
    ``self.fan_in_fan_out`` is set. When ``self.merged`` is true, or
    ``self.r`` is not positive, that projection is returned as is.
    Otherwise the input is passed through ``self.lora_dropout`` and
    ``self.lora_A``, then through ``self.lora_B`` applied as a grouped
    ``conv1d`` with ``sum(self.enable_lora)`` groups. The result goes
    through ``self.zero_pad``, is multiplied by ``self.scaling`` and is
    added in place to the base projection.

    Args:
        x: Input tensor fed to ``F.linear``.

    Returns:
        The base projection, plus the scaled low-rank term when it applies.
    """
    base_weight = self.weight.T if self.fan_in_fan_out else self.weight
    output = F.linear(x, base_weight, bias=self.bias)

    # Nothing to add when the update is already merged or the rank is zero.
    if self.merged or not (self.r > 0):
        return output

    low_rank = F.linear(self.lora_dropout(x), self.lora_A)
    # conv1d convolves over the last dim and treats dim -2 as channels, so
    # the last two dims are swapped going in and swapped back coming out.
    channels_first = low_rank.transpose(-2, -1)
    # A trailing singleton dim turns lora_B into a conv kernel of width 1.
    kernel = self.lora_B.unsqueeze(-1)
    group_count = sum(self.enable_lora)
    delta = F.conv1d(channels_first, kernel, groups=group_count).transpose(-2, -1)

    output += self.zero_pad(delta) * self.scaling
    return output
| 391 | |
| 392 | |
| 393 | class Conv2d(nn.Conv2d, LoRALayer): |