MCPcopy Index your code
hub / github.com/huggingface/diffusers / main

Function main

scripts/convert_hunyuandit_to_diffusers.py:8–236  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

6
7
8def main(args):
9 state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
10
11 if args.load_key != "none":
12 try:
13 state_dict = state_dict[args.load_key]
14 except KeyError:
15 raise KeyError(
16 f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
17 )
18
19 device = "cuda"
20 model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
21 model_config["use_style_cond_and_image_meta_size"] = (
22 args.use_style_cond_and_image_meta_size
23 ) ### version <= v1.1: True; version >= v1.2: False
24
25 # input_size -> sample_size, text_dim -> cross_attention_dim
26 for key in state_dict:
27 print("local:", key)
28
29 model = HunyuanDiT2DModel.from_config(model_config).to(device)
30
31 for key in model.state_dict():
32 print("diffusers:", key)
33
34 num_layers = 40
35 for i in range(num_layers):
36 # attn1
37 # Wkqv -> to_q, to_k, to_v
38 q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
39 q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
40 state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
41 state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
42 state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
43 state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
44 state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
45 state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
46 state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
47 state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
48
49 # q_norm, k_norm -> norm_q, norm_k
50 state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
51 state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
52 state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
53 state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
54
55 state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
56 state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
57 state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
58 state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
59
60 # out_proj -> to_out
61 state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
62 state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
63 state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
64 state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
65

Calls 12

pipe · Function · 0.85
load_config · Method · 0.80
save · Method · 0.80
swap_scale_shift · Function · 0.70
load · Method · 0.45
to · Method · 0.45
from_config · Method · 0.45
state_dict · Method · 0.45
pop · Method · 0.45
load_state_dict · Method · 0.45
from_pretrained · Method · 0.45
save_pretrained · Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…