MCPcopy Index your code
hub / github.com/huggingface/diffusers / main

Function main

scripts/convert_hunyuandit_controlnet_to_diffusers.py:8–211  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

6
7
8def main(args):
9 state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
10
11 if args.load_key != "none":
12 try:
13 state_dict = state_dict[args.load_key]
14 except KeyError:
15 raise KeyError(
16 f"{args.load_key} not found in the checkpoint."
17 "Please load from the following keys:{state_dict.keys()}"
18 )
19 device = "cuda"
20
21 model_config = HunyuanDiT2DControlNetModel.load_config(
22 "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
23 )
24 model_config["use_style_cond_and_image_meta_size"] = (
25 args.use_style_cond_and_image_meta_size
26 ) ### version <= v1.1: True; version >= v1.2: False
27 print(model_config)
28
29 for key in state_dict:
30 print("local:", key)
31
32 model = HunyuanDiT2DControlNetModel.from_config(model_config).to(device)
33
34 for key in model.state_dict():
35 print("diffusers:", key)
36
37 num_layers = 19
38 for i in range(num_layers):
39 # attn1
40 # Wqkv -> to_q, to_k, to_v
41 q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
42 q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
43 state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
44 state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
45 state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
46 state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
47 state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
48 state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
49 state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
50 state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
51
52 # q_norm, k_norm -> norm_q, norm_k
53 state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
54 state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
55 state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
56 state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
57
58 state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
59 state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
60 state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
61 state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
62
63 # out_proj -> to_out
64 state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
65 state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]

Calls 8

load_config · Method · 0.80
load · Method · 0.45
to · Method · 0.45
from_config · Method · 0.45
state_dict · Method · 0.45
pop · Method · 0.45
load_state_dict · Method · 0.45
save_pretrained · Method · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…