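Mark only LoRA parameters as trainable: parameters whose names contain "lora_" or "cif" stay trainable and all others are frozen in place. Args: model: Model instance (an nn.Module) whose parameters are updated in place. bias: Which bias parameters to re-enable after freezing — "none", "all", or "lora_only".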
(model: nn.Module, bias: str = "none")
| 11 | |
| 12 | |
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    """Freeze every parameter of ``model`` except the LoRA ones, in place.

    A parameter keeps its current ``requires_grad`` flag when its name
    contains ``"lora_"`` or ``"cif"``. Every other parameter gets
    ``requires_grad = False``. Biases can then be selectively re-enabled.

    Args:
        model: Module whose parameters are modified in place.
        bias: Which bias parameters to re-enable after freezing:

            * ``"none"`` -- leave all frozen biases frozen.
            * ``"all"`` -- set ``requires_grad = True`` on every parameter
              whose name contains ``"bias"``.
            * ``"lora_only"`` -- set ``requires_grad = True`` only on the
              ``bias`` attribute of ``LoRALayer`` modules, when that
              attribute exists and is not ``None``.

    Raises:
        NotImplementedError: If ``bias`` is not one of the values above.
            The check runs before any parameter is touched, so the model
            is left unchanged in that case.
    """
    # Validate up front so an unsupported mode cannot leave the model
    # half-frozen (previously the freeze loop ran before the check).
    if bias not in ("none", "all", "lora_only"):
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")

    # NOTE(review): names containing "cif" are exempt from freezing alongside
    # LoRA weights; the reason is not visible here -- confirm with the owner.
    for name, param in model.named_parameters():
        if "lora_" not in name and "cif" not in name:
            param.requires_grad = False

    if bias == "all":
        for name, param in model.named_parameters():
            if "bias" in name:
                param.requires_grad = True
    elif bias == "lora_only":
        for module in model.modules():
            if isinstance(module, LoRALayer) and getattr(module, "bias", None) is not None:
                module.bias.requires_grad = True
| 35 | |
| 36 | |
| 37 | def lora_state_dict(model: nn.Module, bias: str = "none") -> Dict[str, torch.Tensor]: |