hub / github.com/huggingface/diffusers / con_pt_to_diffuser

Function con_pt_to_diffuser

scripts/convert_consistency_to_diffusers.py:164–266
(checkpoint_path: str, unet_config)
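The signature leaves `unet_config` untyped. The portion of the body shown on this page indexes it with five string keys, so it is presumably a plain dict. The following is a minimal sketch of a pre-flight check under that assumption; the helper name `missing_unet_config_keys` and the sample values are hypothetical and not part of the script, and the part of the function not shown here may read further keys.

```python
from typing import List

# Keys that the visible portion of con_pt_to_diffuser reads from unet_config.
REQUIRED_KEYS = (
    "num_class_embeds",
    "down_block_types",
    "layers_per_block",
    "attention_head_dim",
    "block_out_channels",
)


def missing_unet_config_keys(unet_config: dict) -> List[str]:
    """Return the required keys that are absent, in REQUIRED_KEYS order (hypothetical helper)."""
    return [key for key in REQUIRED_KEYS if key not in unet_config]


# Illustrative values only; a real config comes from the conversion script itself.
example_config = {
    "num_class_embeds": None,  # None skips the class_embedding copy
    "down_block_types": ["ResnetDownsampleBlock2D", "AttnDownBlock2D"],
    "layers_per_block": 2,
    "attention_head_dim": 64,
    "block_out_channels": [32, 64],
}

print(missing_unet_config_keys(example_config))
```

Running such a check before loading a large checkpoint would surface a `KeyError` early, though whether that is worth the extra step depends on how the config is produced.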

Source from the content-addressed store, hash-verified

def con_pt_to_diffuser(checkpoint_path: str, unet_config):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]

    if unet_config["num_class_embeds"] is not None:
        new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"]

    new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]

    down_block_types = unet_config["down_block_types"]
    layers_per_block = unet_config["layers_per_block"]
    attention_head_dim = unet_config["attention_head_dim"]
    channels_list = unet_config["block_out_channels"]
    current_layer = 1
    prev_channels = channels_list[0]

    for i, layer_type in enumerate(down_block_types):
        current_channels = channels_list[i]
        downsample_block_has_skip = current_channels != prev_channels
        if layer_type == "ResnetDownsampleBlock2D":
            for j in range(layers_per_block):
                new_prefix = f"down_blocks.{i}.resnets.{j}"
                old_prefix = f"input_blocks.{current_layer}.0"
                has_skip = True if j == 0 and downsample_block_has_skip else False
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
                current_layer += 1

        elif layer_type == "AttnDownBlock2D":
            for j in range(layers_per_block):
                new_prefix = f"down_blocks.{i}.resnets.{j}"
                old_prefix = f"input_blocks.{current_layer}.0"
                has_skip = True if j == 0 and downsample_block_has_skip else False
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
                new_prefix = f"down_blocks.{i}.attentions.{j}"
                old_prefix = f"input_blocks.{current_layer}.1"
                new_checkpoint = convert_attention(
                    checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
                )
                current_layer += 1

        if i != len(down_block_types) - 1:
            new_prefix = f"down_blocks.{i}.downsamplers.0"
            old_prefix = f"input_blocks.{current_layer}.0"
            new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
            current_layer += 1

        prev_channels = current_channels

    # hardcoded the mid-block for now
    new_prefix = "mid_block.resnets.0"
    old_prefix = "middle_block.0"
    new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
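The down-block loop above advances a single `current_layer` counter through the original `input_blocks.*` entries, emitting one resnet (plus one attention module for `AttnDownBlock2D`) per layer and one downsampler after every block except the last. The sketch below reproduces only that index bookkeeping with plain strings, so the resulting prefix mapping can be inspected without a checkpoint. The function name `plan_down_block_prefixes` is hypothetical, and the sketch deliberately leaves out the `has_skip` logic and the actual tensor copying done by `convert_resnet` and `convert_attention`.

```python
from typing import Dict, List


def plan_down_block_prefixes(down_block_types: List[str], layers_per_block: int) -> Dict[str, str]:
    """Map each original `input_blocks.*` prefix to its diffusers-style prefix (hypothetical helper)."""
    mapping: Dict[str, str] = {}
    current_layer = 1  # input_blocks.0 is consumed by conv_in
    last_index = len(down_block_types) - 1

    for i, layer_type in enumerate(down_block_types):
        # Other block types add no resnets, mirroring the if/elif in the source.
        if layer_type in ("ResnetDownsampleBlock2D", "AttnDownBlock2D"):
            for j in range(layers_per_block):
                mapping[f"input_blocks.{current_layer}.0"] = f"down_blocks.{i}.resnets.{j}"
                if layer_type == "AttnDownBlock2D":
                    # Attention sits at sub-index 1 of the same input block.
                    mapping[f"input_blocks.{current_layer}.1"] = f"down_blocks.{i}.attentions.{j}"
                current_layer += 1

        # Every block except the last is followed by a downsampler entry.
        if i != last_index:
            mapping[f"input_blocks.{current_layer}.0"] = f"down_blocks.{i}.downsamplers.0"
            current_layer += 1

    return mapping


plan = plan_down_block_prefixes(["ResnetDownsampleBlock2D", "AttnDownBlock2D"], layers_per_block=2)
for old_prefix, new_prefix in plan.items():
    print(old_prefix, "->", new_prefix)
```

With two blocks of two layers each, the first block occupies `input_blocks.1` through `input_blocks.3` (two resnets and a downsampler) and the second occupies `input_blocks.4` and `input_blocks.5`, which is a quick way to sanity-check a config against the key names in a checkpoint.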

Calls 3

convert_resnet · Function · 0.85
convert_attention · Function · 0.85
load · Method · 0.45

Tested by

no test coverage detected
