MCPcopy
hub / github.com/huggingface/diffusers / AttnProcsLayers

Class AttnProcsLayers

src/diffusers/loaders/utils.py:19–58  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

17
18
class AttnProcsLayers(torch.nn.Module):
    """Wrap a dict of attention processors in a single trainable ``nn.Module``.

    The processors are stored positionally in ``self.layers`` (a ``ModuleList``), so
    their parameters are internally named ``layers.<i>.<param>``. A state-dict hook and
    a load-state-dict pre-hook translate between that positional naming and the
    original dict keys (e.g. the keys of ``unet.attn_processors``), so saved
    checkpoints use the same names as the model the processors came from.

    Args:
        state_dict: Mapping from processor name to processor module. Despite the
            parameter name, the values are handed to ``torch.nn.ModuleList`` and so
            must be ``torch.nn.Module`` instances. For ``load_state_dict`` to map a
            key back, it must contain one of ``self.split_keys``.
    """

    def __init__(self, state_dict: dict[str, torch.nn.Module]):
        super().__init__()
        self.layers = torch.nn.ModuleList(state_dict.values())
        # position -> original name, and original name -> position
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        # We add a hook to state_dict() and load_state_dict() so that the naming fits
        # with `unet.attn_processors`. Both hooks read attributes only from their
        # `module` argument and only touch keys under `prefix`. That keeps them correct
        # when this module is nested inside another one (e.g. a container or a
        # wrapper that adds a "module." prefix).
        def map_to(module, state_dict, prefix="", *args, **kwargs):
            """State-dict hook: rename ``<prefix>layers.<i>.<rest>`` to ``<prefix><name_i>.<rest>``."""
            layers_prefix = prefix + "layers."
            renamed = []
            for key, value in state_dict.items():
                if key.startswith(layers_prefix):
                    num, sep, rest = key[len(layers_prefix):].partition(".")
                    # Replace only the leading positional component, not every
                    # occurrence of the substring.
                    key = prefix + module.mapping[int(num)] + sep + rest
                renamed.append((key, value))

            # Rewrite in place rather than returning a new dict. A parent module's
            # state_dict() ignores a child's return value. In-place also keeps the dict
            # type and its `_metadata` attribute, which load_state_dict reads.
            state_dict.clear()
            state_dict.update(renamed)
            return state_dict

        def remap_key(module, key, state_dict):
            """Return ``key`` up to and including the first of ``module.split_keys`` it contains.

            Raises:
                ValueError: if ``key`` contains none of ``module.split_keys``.
            """
            for k in module.split_keys:
                if k in key:
                    return key.split(k)[0] + k

            raise ValueError(
                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {module.split_keys}."
            )

        def map_from(module, state_dict, prefix="", *args, **kwargs):
            """Load pre-hook: rename ``<prefix><name>.<rest>`` back to ``<prefix>layers.<i>.<rest>``.

            Mutates ``state_dict`` in place. An unknown processor name raises ``KeyError``
            from ``rev_mapping``.
            """
            for key in list(state_dict.keys()):
                if not key.startswith(prefix):
                    # Belongs to a sibling/parent module; leave it untouched.
                    continue
                local_key = key[len(prefix):]
                replace_key = remap_key(module, local_key, state_dict)
                new_key = prefix + f"layers.{module.rev_mapping[replace_key]}" + local_key[len(replace_key):]
                # pop-then-assign keeps the entry even if new_key happens to equal key
                state_dict[new_key] = state_dict.pop(key)

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)

Callers 7

main · Function · 0.90
main · Function · 0.90
main · Function · 0.90
main · Function · 0.90
main · Function · 0.90
train_lora · Method · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…