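Set the layer's training mode and unmerge LoRA weights if they are merged. Delegates to `nn.Linear.train(self, mode)`. Then, if `merge_weights` and `merged` are both set, it subtracts `T(lora_B @ lora_A) * scaling` from `weight` (only when `r > 0`) and resets `merged` to `False`. This unmerge step runs regardless of the value of `mode`. Args: mode: flag passed through unchanged to `nn.Linear.train` (defaults to `True`).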
(self, mode: bool = True)
| 177 | nn.init.zeros_(self.lora_B) |
| 178 | |
| 179 | def train(self, mode: bool = True): |
| 180 | """Train. |
| 181 | |
| 182 | Args: |
| 183 | mode: TODO. |
| 184 | """ |
| 185 | def T(w): |
| 186 | """T. |
| 187 | |
| 188 | Args: |
| 189 | w: TODO. |
| 190 | """ |
| 191 | return w.T if self.fan_in_fan_out else w |
| 192 | |
| 193 | nn.Linear.train(self, mode) |
| 194 | if self.merge_weights and self.merged: |
| 195 | # Make sure that the weights are not merged |
| 196 | if self.r > 0: |
| 197 | self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling |
| 198 | self.merged = False |
| 199 | |
| 200 | def eval(self): |
| 201 | """Eval.""" |
No test coverage detected.