This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new checkpoint.
(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
)
| 137 | |
| 138 | |
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.

    Args:
        paths: list of dicts with ``"old"`` and ``"new"`` keys; ``"old"`` indexes ``old_checkpoint`` and
            ``"new"`` is the locally converted name that the global renaming is applied to.
        checkpoint: dict receiving the converted weights; it is mutated in place.
        old_checkpoint: dict holding the original tensors.
        attention_paths_to_split: optional mapping ``old_key -> {"query": ..., "key": ..., "value": ...}``.
            Each ``old_key`` tensor is split into three tensors stored in ``checkpoint`` under the
            given query/key/value names.
        additional_replacements: optional list of ``{"old": str, "new": str}`` substring replacements
            applied to every new path after the global ``middle_block`` renaming.
        config: mapping providing ``"num_head_channels"``; required when ``attention_paths_to_split``
            is given.

    Raises:
        AssertionError: if ``paths`` is not a list.
        ValueError: if ``attention_paths_to_split`` is given without a ``config`` containing
            ``"num_head_channels"``.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        # Fail early with a clear message instead of an opaque TypeError on `None["num_head_channels"]`.
        if config is None or "num_head_channels" not in config:
            raise ValueError(
                "A `config` containing 'num_head_channels' is required when `attention_paths_to_split` is given."
            )
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            # The fused tensor holds query, key and value stacked along dim 0.
            channels = old_tensor.shape[0] // 3

            # 3-D weights (presumably conv1d with kernel size 1) become 2-D (-1, channels);
            # anything else (e.g. biases) is flattened.
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1,)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            # Dim 0 is treated as per-head blocks of [q, k, v]: group by head first, then split each
            # head's block so query/key/value each gather their slice from every head.
            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned.
        # NOTE(review): the split above is keyed by *old* checkpoint keys, while this compares the *new*
        # path against those keys -- confirm this is intended by the callers.
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        weight = old_checkpoint[path["old"]]
        # proj_attn.weight has to be converted from a conv weight to a linear weight by dropping the
        # trailing kernel dimension(s); assumes a kernel size of 1. 2-D weights are already linear.
        if "proj_attn.weight" in new_path:
            if len(weight.shape) == 3:
                weight = weight[:, :, 0]
            elif len(weight.shape) == 4:
                weight = weight[:, :, 0, 0]
        checkpoint[new_path] = weight
| 189 | |
| 190 | |
| 191 | def conv_attn_to_linear(checkpoint): |
no test coverage detected