github.com/huggingface/diffusers, branch main

Function main(args)

scripts/convert_dit_to_diffusers.py, lines 26–134

```python
import torch

# download_model and pretrained_models are defined at module level in this
# script (not shown in this excerpt).


def main(args):
    state_dict = download_model(pretrained_models[args.image_size])

    # Rename the patch-embedding projection: x_embedder -> pos_embed.
    state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
    state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
    state_dict.pop("x_embedder.proj.weight")
    state_dict.pop("x_embedder.proj.bias")

    # The block count is hard-coded to 28 (the depth of DiT-XL).
    for depth in range(28):
        # The checkpoint holds one timestep embedder and one class embedder for the
        # whole model; copy them into every block's norm1.emb.
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[
            "t_embedder.mlp.0.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[
            "t_embedder.mlp.0.bias"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[
            "t_embedder.mlp.2.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[
            "t_embedder.mlp.2.bias"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[
            "y_embedder.embedding_table.weight"
        ]

        # adaLN modulation projection -> norm1.linear.
        state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[
            f"blocks.{depth}.adaLN_modulation.1.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[
            f"blocks.{depth}.adaLN_modulation.1.bias"
        ]

        # Split the fused qkv projection into separate q, k, v tensors
        # (stacked in that order along dim 0).
        q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)

        state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
        state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
        state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias

        # Attention output projection.
        state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[
            f"blocks.{depth}.attn.proj.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]

        # Feed-forward (MLP) layers.
        state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"]
        state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"]
        state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"]
        state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"]

        # Drop original per-block keys that have been converted
        # (the excerpt is cut off partway through these pops).
        state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
        state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
        state_dict.pop(f"blocks.{depth}.attn.proj.weight")
        state_dict.pop(f"blocks.{depth}.attn.proj.bias")
        state_dict.pop(f"blocks.{depth}.mlp.fc1.weight")
        state_dict.pop(f"blocks.{depth}.mlp.fc1.bias")
        # ... (the rest of main, source lines 84–134, is not shown in this excerpt)
```
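`main` performs two other kinds of conversion besides the `x_embedder` rename and the fused `qkv` split. The first is a fixed rename applied once per block. The second copies the model-level timestep and class embedder weights into each block. A lookup table is one way to make that structure explicit. The stdlib-only sketch below illustrates the idea and is not the script's implementation:

- `PER_BLOCK_RENAMES`, `SHARED_TO_PER_BLOCK`, and `remap_keys` are hypothetical names.
- String placeholders stand in for tensors.
- Dropping the shared originals at the end is a choice made by this sketch, not something the excerpt shows.

```python
# Hypothetical mapping tables (not from the script) mirroring the one-to-one
# renames in main; "{d}" is replaced by the block index.
PER_BLOCK_RENAMES = {
    "blocks.{d}.adaLN_modulation.1.weight": "transformer_blocks.{d}.norm1.linear.weight",
    "blocks.{d}.adaLN_modulation.1.bias": "transformer_blocks.{d}.norm1.linear.bias",
    "blocks.{d}.attn.proj.weight": "transformer_blocks.{d}.attn1.to_out.0.weight",
    "blocks.{d}.attn.proj.bias": "transformer_blocks.{d}.attn1.to_out.0.bias",
    "blocks.{d}.mlp.fc1.weight": "transformer_blocks.{d}.ff.net.0.proj.weight",
    "blocks.{d}.mlp.fc1.bias": "transformer_blocks.{d}.ff.net.0.proj.bias",
    "blocks.{d}.mlp.fc2.weight": "transformer_blocks.{d}.ff.net.2.weight",
    "blocks.{d}.mlp.fc2.bias": "transformer_blocks.{d}.ff.net.2.bias",
}

# Model-level embedder weights that main copies into every block's norm1.emb.
SHARED_TO_PER_BLOCK = {
    "t_embedder.mlp.0.weight": "transformer_blocks.{d}.norm1.emb.timestep_embedder.linear_1.weight",
    "t_embedder.mlp.0.bias": "transformer_blocks.{d}.norm1.emb.timestep_embedder.linear_1.bias",
    "t_embedder.mlp.2.weight": "transformer_blocks.{d}.norm1.emb.timestep_embedder.linear_2.weight",
    "t_embedder.mlp.2.bias": "transformer_blocks.{d}.norm1.emb.timestep_embedder.linear_2.bias",
    "y_embedder.embedding_table.weight": "transformer_blocks.{d}.norm1.emb.class_embedder.embedding_table.weight",
}


def remap_keys(state_dict, num_blocks):
    """Return a new dict with the table-driven renames applied (the input is not modified)."""
    out = dict(state_dict)
    for d in range(num_blocks):
        for old, new in SHARED_TO_PER_BLOCK.items():
            # Copy, not move: every block receives the same shared value.
            out[new.format(d=d)] = state_dict[old]
        for old, new in PER_BLOCK_RENAMES.items():
            # Move: the block-specific original key is removed.
            out[new.format(d=d)] = out.pop(old.format(d=d))
    # Drop the shared originals once every block has its copy (a choice of this sketch).
    for old in SHARED_TO_PER_BLOCK:
        out.pop(old)
    return out


# Toy checkpoint with 2 blocks; each value is just the name of its original key.
toy = {key: key for key in SHARED_TO_PER_BLOCK}
for d in range(2):
    for key in PER_BLOCK_RENAMES:
        toy[key.format(d=d)] = key.format(d=d)

converted = remap_keys(toy, num_blocks=2)
print(converted["transformer_blocks.1.ff.net.2.bias"])  # blocks.1.mlp.fc2.bias
print(len(converted))  # 26
```

A table keeps the whole key mapping in one place, which makes it easy to diff against a target model's expected keys. The trade-off is one more layer of indirection compared with the explicit assignments in the script.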

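The `torch.chunk(..., 3, dim=0)` call relies on the fused `qkv` weight stacking the query, key, and value rows along the output (first) dimension, with the bias laid out the same way. Under that layout, projecting with the fused matrix and slicing the output into thirds gives the same numbers as projecting with the three chunks separately. The stdlib-only sketch below checks this on a tiny made-up example. Nested lists stand in for tensors, and `split_dim0` and `linear` are hypothetical helpers, not part of the script.

```python
def split_dim0(values, chunks):
    """Split a list into `chunks` equal parts along its first dimension
    (what torch.chunk(..., chunks, dim=0) does when the length divides evenly)."""
    if len(values) % chunks:
        raise ValueError("length must be divisible by the number of chunks")
    size = len(values) // chunks
    return [values[i * size:(i + 1) * size] for i in range(chunks)]


def linear(weight, bias, x):
    """Compute weight @ x + bias for a row-major list-of-lists weight."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


# Fused qkv layer for hidden size 2: the weight has 3 * 2 rows (q rows, then k, then v).
qkv_weight = [[1, 0], [0, 1], [2, 1], [1, 2], [3, 0], [0, 3]]
qkv_bias = [0, 1, 2, 3, 4, 5]
x = [1, -1]

q_w, k_w, v_w = split_dim0(qkv_weight, 3)
q_b, k_b, v_b = split_dim0(qkv_bias, 3)

fused = linear(qkv_weight, qkv_bias, x)
separate = linear(q_w, q_b, x) + linear(k_w, k_b, x) + linear(v_w, v_b, x)
print(fused)              # [1, 0, 3, 2, 7, 2]
print(fused == separate)  # True
```

The equivalence depends on the checkpoint's layout. If q, k, and v were interleaved per attention head instead of stacked as three contiguous blocks, a plain three-way chunk would not recover them correctly.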
Callers: 1

Calls (8):
Transformer2DModel (class) · 0.90
DDIMScheduler (class) · 0.90
DiTPipeline (class) · 0.90
download_model (function) · 0.85
pop (method) · 0.45
load_state_dict (method) · 0.45
from_pretrained (method) · 0.45
save_pretrained (method) · 0.45

Tested by: no test coverage detected.
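No tests are detected for this script. A cheap sanity check is to compare the converted key set against the keys the target model expects before loading. This mirrors what PyTorch's `Module.load_state_dict` reports as missing and unexpected keys, and it raises an error on any mismatch when `strict=True`. The stdlib-only sketch below is hypothetical: `check_keys` is not part of the script or of diffusers. In practice the expected keys would come from the target model's own `state_dict().keys()`.

```python
def check_keys(converted_keys, expected_keys):
    """Return (missing, unexpected) key lists, sorted for stable output."""
    converted, expected = set(converted_keys), set(expected_keys)
    missing = sorted(expected - converted)     # expected by the model but not produced
    unexpected = sorted(converted - expected)  # produced but unknown to the model
    return missing, unexpected


expected = ["pos_embed.proj.weight", "pos_embed.proj.bias", "transformer_blocks.0.attn1.to_q.weight"]
converted = ["pos_embed.proj.weight", "pos_embed.proj.bias", "blocks.0.attn.qkv.weight"]
missing, unexpected = check_keys(converted, expected)
print(missing)     # ['transformer_blocks.0.attn1.to_q.weight']
print(unexpected)  # ['blocks.0.attn.qkv.weight']
```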
