hub / github.com/TheLastBen/fast-stable-diffusion / assign_to_checkpoint

Function assign_to_checkpoint

Dreambooth/convertodiffv2.py:139–188

Performs the final conversion step: takes the locally converted weights, applies a global renaming to their paths, and assigns them to the new checkpoint. Along the way it splits each listed attention tensor into query, key, and value tensors and applies any additional replacements the caller supplies.

(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
)
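The renaming half of the function can be tried in isolation. The following is a minimal sketch, not the repository's code: `rename_path` is a hypothetical helper that mirrors the three fixed `middle_block` renames and the `additional_replacements` loop, and the checkpoint keys and list values are made-up stand-ins for real tensors.

```python
def rename_path(new_path, additional_replacements=None):
    """Hypothetical helper mirroring the renaming steps of assign_to_checkpoint."""
    # Fixed global renames, applied to every path.
    new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
    new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
    new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

    # Caller-supplied replacements run afterwards, in list order.
    if additional_replacements is not None:
        for replacement in additional_replacements:
            new_path = new_path.replace(replacement["old"], replacement["new"])
    return new_path


# Illustrative data only: plain lists stand in for weight tensors.
old_checkpoint = {
    "model.middle_block.0.in_layers.0.weight": [1.0, 2.0],
    "model.input_blocks.1.0.in_layers.0.weight": [3.0, 4.0],
}
paths = [
    {"old": "model.middle_block.0.in_layers.0.weight", "new": "middle_block.0.norm1.weight"},
    {"old": "model.input_blocks.1.0.in_layers.0.weight", "new": "input_blocks.1.0.norm1.weight"},
]
extra = [{"old": "input_blocks.1.0", "new": "down_blocks.0.resnets.0"}]

checkpoint = {}
for path in paths:
    checkpoint[rename_path(path["new"], extra)] = old_checkpoint[path["old"]]
```

Each entry in `paths` pairs a key in the old checkpoint with a provisional new key; only the new key is rewritten, and the value is copied across unchanged.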

Source from the content-addressed store, hash-verified

def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]
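The attention split is easier to follow on a small array. The sketch below reproduces the same reshape-and-split arithmetic with NumPy instead of the tensor methods used in the source (`np.split` into three equal sections stands in for `.split(channels // num_heads, dim=1)`). `split_qkv` is a hypothetical name, and the sizes are chosen only to keep the output readable.

```python
import numpy as np


def split_qkv(old_tensor, num_head_channels):
    """Hypothetical NumPy re-creation of the attention split in assign_to_checkpoint."""
    channels = old_tensor.shape[0] // 3
    # 3-D conv weights become 2-D matrices; 1-D biases stay flat.
    target_shape = (-1, channels) if old_tensor.ndim == 3 else (-1,)
    num_heads = old_tensor.shape[0] // num_head_channels // 3

    # Group rows per head: each head holds its q, k and v rows back to back.
    per_head = old_tensor.reshape(
        (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
    )
    # Three equal sections along axis 1 give q, k and v for every head.
    query, key, value = np.split(per_head, 3, axis=1)
    return (
        query.reshape(target_shape),
        key.reshape(target_shape),
        value.reshape(target_shape),
    )


# A fused bias with channels=4 and two heads of two channels each.
bias = np.arange(12)
q_b, k_b, v_b = split_qkv(bias, num_head_channels=2)

# A fused conv-1D weight of shape (3 * channels, channels, 1).
weight = np.arange(48).reshape(12, 4, 1)
q_w, k_w, v_w = split_qkv(weight, num_head_channels=2)
```

With these inputs `q_b` is `[0, 1, 6, 7]`, which shows that the fused tensor is laid out head by head (query, key, value for head 0, then for head 1) rather than as three contiguous thirds.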

Callers 2

Calls 1

replaceMethod · 0.80

Tested by

no test coverage detected