(checkpoint_path: str, unet_config)
| 162 | |
| 163 | |
| 164 | def con_pt_to_diffuser(checkpoint_path: str, unet_config): |
| 165 | checkpoint = torch.load(checkpoint_path, map_location="cpu") |
| 166 | new_checkpoint = {} |
| 167 | |
| 168 | new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"] |
| 169 | new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"] |
| 170 | new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"] |
| 171 | new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"] |
| 172 | |
| 173 | if unet_config["num_class_embeds"] is not None: |
| 174 | new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"] |
| 175 | |
| 176 | new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"] |
| 177 | new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"] |
| 178 | |
| 179 | down_block_types = unet_config["down_block_types"] |
| 180 | layers_per_block = unet_config["layers_per_block"] |
| 181 | attention_head_dim = unet_config["attention_head_dim"] |
| 182 | channels_list = unet_config["block_out_channels"] |
| 183 | current_layer = 1 |
| 184 | prev_channels = channels_list[0] |
| 185 | |
| 186 | for i, layer_type in enumerate(down_block_types): |
| 187 | current_channels = channels_list[i] |
| 188 | downsample_block_has_skip = current_channels != prev_channels |
| 189 | if layer_type == "ResnetDownsampleBlock2D": |
| 190 | for j in range(layers_per_block): |
| 191 | new_prefix = f"down_blocks.{i}.resnets.{j}" |
| 192 | old_prefix = f"input_blocks.{current_layer}.0" |
| 193 | has_skip = True if j == 0 and downsample_block_has_skip else False |
| 194 | new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip) |
| 195 | current_layer += 1 |
| 196 | |
| 197 | elif layer_type == "AttnDownBlock2D": |
| 198 | for j in range(layers_per_block): |
| 199 | new_prefix = f"down_blocks.{i}.resnets.{j}" |
| 200 | old_prefix = f"input_blocks.{current_layer}.0" |
| 201 | has_skip = True if j == 0 and downsample_block_has_skip else False |
| 202 | new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip) |
| 203 | new_prefix = f"down_blocks.{i}.attentions.{j}" |
| 204 | old_prefix = f"input_blocks.{current_layer}.1" |
| 205 | new_checkpoint = convert_attention( |
| 206 | checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim |
| 207 | ) |
| 208 | current_layer += 1 |
| 209 | |
| 210 | if i != len(down_block_types) - 1: |
| 211 | new_prefix = f"down_blocks.{i}.downsamplers.0" |
| 212 | old_prefix = f"input_blocks.{current_layer}.0" |
| 213 | new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) |
| 214 | current_layer += 1 |
| 215 | |
| 216 | prev_channels = current_channels |
| 217 | |
| 218 | # hardcoded the mid-block for now |
| 219 | new_prefix = "mid_block.resnets.0" |
| 220 | old_prefix = "middle_block.0" |
| 221 | new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) |
no test coverage detected
searching dependent graphs…