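Eval: put the layer into evaluation mode via `nn.Linear.eval`. Then, if `merge_weights` is enabled and the weights are not already merged, add the LoRA update `T(lora_B @ lora_A) * scaling` to `weight` in place (only when `r > 0`) and set `merged = True`.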
(self)
| 198 | self.merged = False |
| 199 | |
| 200 | def eval(self): |
| 201 | """Eval.""" |
| 202 | def T(w): |
| 203 | """T. |
| 204 | |
| 205 | Args: |
| 206 | w: TODO. |
| 207 | """ |
| 208 | return w.T if self.fan_in_fan_out else w |
| 209 | |
| 210 | nn.Linear.eval(self) |
| 211 | if self.merge_weights and not self.merged: |
| 212 | # Merge the weights and mark it |
| 213 | if self.r > 0: |
| 214 | self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling |
| 215 | self.merged = True |
| 216 | |
| 217 | def forward(self, x: torch.Tensor): |
| 218 | """Forward pass for training. |
no test coverage detected