| 17 | |
| 18 | |
class AttnProcsLayers(torch.nn.Module):
    """Store named modules in a ``ModuleList`` while keeping their names in checkpoints.

    The modules passed in are registered positionally under ``self.layers``,
    so their parameters would normally be serialized as ``layers.<idx>.<param>``.
    Two hooks translate between that positional form and the original names:

    * ``state_dict()`` output is rewritten from ``layers.<idx>.<param>`` to
      ``<name>.<param>``;
    * ``load_state_dict()`` input is rewritten from ``<name>.<param>`` back to
      ``layers.<idx>.<param>`` before loading.

    Args:
        state_dict: Mapping from a name to the module stored under that name.
            When loading, every incoming key must contain one of
            ``self.split_keys`` (``".processor"`` or ``".self_attn"``); the
            text up to and including that marker is taken as the module name.

    Attributes set on the instance:
        layers: ``ModuleList`` of the mapping's values, in iteration order.
        mapping: ``{index: name}``.
        rev_mapping: ``{name: index}``.
        split_keys: Markers that terminate a module name inside a key.
    """

    def __init__(self, state_dict: dict[str, torch.nn.Module]):
        super().__init__()
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        def map_to(module, state_dict, *args, **kwargs):
            """Rename ``layers.<idx>.<rest>`` keys to ``<name>.<rest>``."""
            new_state_dict = {}
            for key, value in state_dict.items():
                # Element 0 is always "layers", element 1 is the list index.
                # NOTE(review): this assumes the hook runs with an empty key
                # prefix (this module is the root of the state_dict() call).
                num = int(key.split(".")[1])
                # Replace only the leading occurrence: a wrapped module may
                # itself own a child called ``layers`` and the same index can
                # legitimately reappear further down the key.
                new_key = key.replace(f"layers.{num}", module.mapping[num], 1)
                new_state_dict[new_key] = value

            return new_state_dict

        def remap_key(key, state_dict, split_keys):
            """Return the module-name prefix of ``key`` (through the first marker found)."""
            for k in split_keys:
                if k in key:
                    return key.split(k)[0] + k

            raise ValueError(
                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {split_keys}."
            )

        def map_from(module, state_dict, *args, **kwargs):
            """Rename ``<name>.<rest>`` keys to ``layers.<idx>.<rest>`` in place."""
            # Snapshot the keys first: the dict is mutated while we walk it.
            for key in list(state_dict.keys()):
                replace_key = remap_key(key, state_dict, module.split_keys)
                # ``replace_key`` is always a prefix of ``key``; limit the
                # replacement to that prefix so a repeated name deeper in the
                # key is left untouched.
                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}", 1)
                state_dict[new_key] = state_dict.pop(key)

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)
no outgoing calls
no test coverage detected
searching dependent graphs…