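Initialize MergedLinear: an nn.Linear whose out_features are split into len(enable_lora) equal groups, with LoRA factors (lora_A, lora_B) attached only to the enabled groups; when LoRA is active the base weight is frozen (requires_grad = False). Args: in_features: size of each input sample; forwarded to nn.Linear and used as the column count of lora_A. out_features: size of each output sample; forwarded to nn.Linear and must be divisible by len(enable_lora) (enforced by an assert). r: inner dimension (rank) of the LoRA factors; lora_A and lora_B are created only if r > 0 and at least one entry of enable_lora is True. lora_alpha: numerator of the LoRA scaling factor (scaling = lora_alpha / r). lora_dropout: presumably a dropout probability; forwarded to LoRALayer. enable_lora: one flag per output group; True enables LoRA for that group (recorded in the boolean mask lora_ind). fan_in_fan_out: if True, the base weight is stored transposed (weight.data = weight.data.T) after reset_parameters(). merge_weights: forwarded to LoRALayer. **kwargs: additional keyword arguments forwarded to nn.Linear.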
(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
enable_lora: List[bool] = [False],
fan_in_fan_out: bool = False,
merge_weights: bool = True,
**kwargs
)
| 240 | class MergedLinear(nn.Linear, LoRALayer): |
| 241 | # LoRA implemented in a dense layer |
| 242 | def __init__( |
| 243 | self, |
| 244 | in_features: int, |
| 245 | out_features: int, |
| 246 | r: int = 0, |
| 247 | lora_alpha: int = 1, |
| 248 | lora_dropout: float = 0.0, |
| 249 | enable_lora: List[bool] = [False], |
| 250 | fan_in_fan_out: bool = False, |
| 251 | merge_weights: bool = True, |
| 252 | **kwargs |
| 253 | ): |
| 254 | """Initialize MergedLinear. |
| 255 | |
| 256 | Args: |
| 257 | in_features: TODO. |
| 258 | out_features: TODO. |
| 259 | r: TODO. |
| 260 | lora_alpha: TODO. |
| 261 | lora_dropout: TODO. |
| 262 | enable_lora: TODO. |
| 263 | fan_in_fan_out: TODO. |
| 264 | merge_weights: TODO. |
| 265 | **kwargs: Additional keyword arguments. |
| 266 | """ |
| 267 | nn.Linear.__init__(self, in_features, out_features, **kwargs) |
| 268 | LoRALayer.__init__( |
| 269 | self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights |
| 270 | ) |
| 271 | assert ( |
| 272 | out_features % len(enable_lora) == 0 |
| 273 | ), "The length of enable_lora must divide out_features" |
| 274 | self.enable_lora = enable_lora |
| 275 | self.fan_in_fan_out = fan_in_fan_out |
| 276 | # Actual trainable parameters |
| 277 | if r > 0 and any(enable_lora): |
| 278 | self.lora_A = nn.Parameter(self.weight.new_zeros((r * sum(enable_lora), in_features))) |
| 279 | self.lora_B = nn.Parameter( |
| 280 | self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r)) |
| 281 | ) # weights for Conv1D with groups=sum(enable_lora) |
| 282 | self.scaling = self.lora_alpha / self.r |
| 283 | # Freezing the pre-trained weight matrix |
| 284 | self.weight.requires_grad = False |
| 285 | # Compute the indices |
| 286 | self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view( |
| 287 | len(enable_lora), -1 |
| 288 | ) |
| 289 | self.lora_ind[enable_lora, :] = True |
| 290 | self.lora_ind = self.lora_ind.view(-1) |
| 291 | self.reset_parameters() |
| 292 | if fan_in_fan_out: |
| 293 | self.weight.data = self.weight.data.T |
| 294 | |
| 295 | def reset_parameters(self): |
| 296 | """Reset parameters.""" |
nothing calls this directly
no test coverage detected