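Zero pad: return x scattered into a zero tensor whose last dimension is out_features. Args: x: input tensor whose last dimension is expected to equal out_features // len(enable_lora) * sum(enable_lora); its values are written to the columns selected by lora_ind, and all other columns stay zero.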
(self, x)
| 301 | nn.init.zeros_(self.lora_B) |
| 302 | |
| 303 | def zero_pad(self, x): |
| 304 | """Return x scattered into a zero tensor whose last dim is out_features. |
| 305 | |
| 306 | Args: |
| 307 | x: Tensor whose last dim is expected to be out_features // len(enable_lora) * sum(enable_lora); its values fill columns self.lora_ind and all other columns stay zero. |
| 308 | """ |
| 309 | result = x.new_zeros((*x.shape[:-1], self.out_features)) |
| 310 | result = result.view(-1, self.out_features)  # flatten leading dims to 2-D so lora_ind can index columns directly |
| 311 | result[:, self.lora_ind] = x.reshape(  # write x into the enabled columns; the rest keep their zeros |
| 312 | -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora) |
| 313 | ) |
| 314 | return result.view((*x.shape[:-1], self.out_features))  # restore x's leading dims |
| 315 | |
| 316 | def train(self, mode: bool = True): |
| 317 | """Train. |