r""" Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
(model)
| 33 | |
| 34 | |
| 35 | def recurse_remove_peft_layers(model): |
| 36 | r""" |
| 37 | Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`. |
| 38 | """ |
| 39 | from peft.tuners.tuners_utils import BaseTunerLayer |
| 40 | |
| 41 | has_base_layer_pattern = False |
| 42 | for module in model.modules(): |
| 43 | if isinstance(module, BaseTunerLayer): |
| 44 | has_base_layer_pattern = hasattr(module, "base_layer") |
| 45 | break |
| 46 | |
| 47 | if has_base_layer_pattern: |
| 48 | from peft.utils import _get_submodules |
| 49 | |
| 50 | key_list = [key for key, _ in model.named_modules() if "lora" not in key] |
| 51 | for key in key_list: |
| 52 | try: |
| 53 | parent, target, target_name = _get_submodules(model, key) |
| 54 | except AttributeError: |
| 55 | continue |
| 56 | if hasattr(target, "base_layer"): |
| 57 | setattr(parent, target_name, target.get_base_layer()) |
| 58 | else: |
| 59 | # This is for backwards compatibility with PEFT <= 0.6.2. |
| 60 | # TODO can be removed once that PEFT version is no longer supported. |
| 61 | from peft.tuners.lora import LoraLayer |
| 62 | |
| 63 | for name, module in model.named_children(): |
| 64 | if len(list(module.children())) > 0: |
| 65 | ## compound module, go inside it |
| 66 | recurse_remove_peft_layers(module) |
| 67 | |
| 68 | module_replaced = False |
| 69 | |
| 70 | if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear): |
| 71 | new_module = torch.nn.Linear( |
| 72 | module.in_features, |
| 73 | module.out_features, |
| 74 | bias=module.bias is not None, |
| 75 | ).to(module.weight.device) |
| 76 | new_module.weight = module.weight |
| 77 | if module.bias is not None: |
| 78 | new_module.bias = module.bias |
| 79 | |
| 80 | module_replaced = True |
| 81 | elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d): |
| 82 | new_module = torch.nn.Conv2d( |
| 83 | module.in_channels, |
| 84 | module.out_channels, |
| 85 | module.kernel_size, |
| 86 | module.stride, |
| 87 | module.padding, |
| 88 | module.dilation, |
| 89 | module.groups, |
| 90 | ).to(module.weight.device) |
| 91 | |
| 92 | new_module.weight = module.weight |
no test coverage detected
searching dependent graphs…