MCPcopy Index your code
hub / github.com/huggingface/diffusers / recurse_remove_peft_layers

Function recurse_remove_peft_layers

src/diffusers/utils/peft_utils.py:35–103  ·  view source on GitHub ↗

Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.

(model)

Source from the content-addressed store, hash-verified

33
34
35def recurse_remove_peft_layers(model):
36 r"""
37 Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
38 """
39 from peft.tuners.tuners_utils import BaseTunerLayer
40
41 has_base_layer_pattern = False
42 for module in model.modules():
43 if isinstance(module, BaseTunerLayer):
44 has_base_layer_pattern = hasattr(module, "base_layer")
45 break
46
47 if has_base_layer_pattern:
48 from peft.utils import _get_submodules
49
50 key_list = [key for key, _ in model.named_modules() if "lora" not in key]
51 for key in key_list:
52 try:
53 parent, target, target_name = _get_submodules(model, key)
54 except AttributeError:
55 continue
56 if hasattr(target, "base_layer"):
57 setattr(parent, target_name, target.get_base_layer())
58 else:
59 # This is for backwards compatibility with PEFT <= 0.6.2.
60 # TODO can be removed once that PEFT version is no longer supported.
61 from peft.tuners.lora import LoraLayer
62
63 for name, module in model.named_children():
64 if len(list(module.children())) > 0:
65 ## compound module, go inside it
66 recurse_remove_peft_layers(module)
67
68 module_replaced = False
69
70 if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
71 new_module = torch.nn.Linear(
72 module.in_features,
73 module.out_features,
74 bias=module.bias is not None,
75 ).to(module.weight.device)
76 new_module.weight = module.weight
77 if module.bias is not None:
78 new_module.bias = module.bias
79
80 module_replaced = True
81 elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
82 new_module = torch.nn.Conv2d(
83 module.in_channels,
84 module.out_channels,
85 module.kernel_size,
86 module.stride,
87 module.padding,
88 module.dilation,
89 module.groups,
90 ).to(module.weight.device)
91
92 new_module.weight = module.weight

Callers 2

unload_lora (method) · 0.85

Calls 2

empty_device_cache (function) · 0.85
to (method) · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…