hub / github.com/huggingface/diffusers / main

Function main

scripts/convert_sana_controlnet_to_diffusers.py:20–152
(args)

Source from the content-addressed store, hash-verified

def main(args):
    file_path = args.orig_ckpt_path

    all_state_dict = torch.load(file_path, weights_only=True)
    state_dict = all_state_dict.pop("state_dict")
    converted_state_dict = {}

    # Patch embeddings.
    converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
    converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

    # Caption projection.
    converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
    converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
    converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
    converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")

    # AdaLN-single LN
    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
        "t_embedder.mlp.0.weight"
    )
    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
        "t_embedder.mlp.2.weight"
    )
    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")

    # Shared norm.
    converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
    converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")

    # y norm
    converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")

    # Positional embedding interpolation scale.
    interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}

    # ControlNet Input Projection.
    converted_state_dict["input_block.weight"] = state_dict.pop("controlnet.0.before_proj.weight")
    converted_state_dict["input_block.bias"] = state_dict.pop("controlnet.0.before_proj.bias")

    for depth in range(7):
        # Transformer blocks.
        converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
            f"controlnet.{depth}.copied_block.scale_shift_table"
        )

        # Linear Attention is all you need 🤘
        # Self attention.
        q, k, v = torch.chunk(state_dict.pop(f"controlnet.{depth}.copied_block.attn.qkv.weight"), 3, dim=0)
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        # Projection.
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
            f"controlnet.{depth}.copied_block.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
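
The excerpt above relies on two patterns: moving each tensor to its new key with `pop`, and splitting a fused `qkv` weight into three equal parts along dimension 0. The following is a minimal sketch of both, not the script's actual code. It assumes plain nested lists as stand-ins for tensors so that it runs without `torch`; the `chunk_rows` helper, the `RENAMES` table, and the toy checkpoint values are all hypothetical names and data introduced for illustration.

```python
def chunk_rows(rows, chunks):
    """Split a list of rows into `chunks` equal consecutive parts.

    Hypothetical stand-in for torch.chunk(..., dim=0). Unlike torch.chunk,
    it refuses sizes that do not divide evenly instead of returning a
    shorter final chunk.
    """
    if len(rows) % chunks != 0:
        raise ValueError("row count must be divisible by the number of chunks")
    size = len(rows) // chunks
    return [rows[i * size:(i + 1) * size] for i in range(chunks)]


# Toy checkpoint (made-up values); nested lists play the role of tensors.
original = {
    "x_embedder.proj.weight": [[1.0, 2.0]],
    "x_embedder.proj.bias": [0.5],
    "controlnet.0.copied_block.attn.qkv.weight": [
        [1, 1], [2, 2],  # rows that become to_q
        [3, 3], [4, 4],  # rows that become to_k
        [5, 5], [6, 6],  # rows that become to_v
    ],
}

# One-to-one renames: original key -> converted key.
RENAMES = {
    "x_embedder.proj.weight": "patch_embed.proj.weight",
    "x_embedder.proj.bias": "patch_embed.proj.bias",
}

converted = {}
for old_key, new_key in RENAMES.items():
    # pop() removes the entry, so whatever remains was never converted.
    converted[new_key] = original.pop(old_key)

# Fused qkv weight: split into three equal row blocks.
depth = 0
q, k, v = chunk_rows(original.pop(f"controlnet.{depth}.copied_block.attn.qkv.weight"), 3)
prefix = f"transformer_blocks.{depth}.attn1"
converted[f"{prefix}.to_q.weight"] = q
converted[f"{prefix}.to_k.weight"] = k
converted[f"{prefix}.to_v.weight"] = v

# Keys still left in `original` would indicate unmapped weights.
leftover = sorted(original)
print(sorted(converted))
print("unconverted keys:", leftover)
```

Because every entry is popped as it is converted, checking what is left in the source dictionary is a cheap way to spot weights that were never mapped. Whether the script itself performs such a check is not visible in this excerpt.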

Calls 9

SanaControlNetModel · Class · 0.90
is_accelerate_available · Function · 0.90
parameters · Method · 0.80
load · Method · 0.45
pop · Method · 0.45
load_state_dict · Method · 0.45
to · Method · 0.45
save_pretrained · Method · 0.45

Tested by

no test coverage detected
