(args)
| 6 | |
| 7 | |
| 8 | def main(args): |
| 9 | state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu") |
| 10 | |
| 11 | if args.load_key != "none": |
| 12 | try: |
| 13 | state_dict = state_dict[args.load_key] |
| 14 | except KeyError: |
| 15 | raise KeyError( |
| 16 | f"{args.load_key} not found in the checkpoint." |
| 17 | "Please load from the following keys:{state_dict.keys()}" |
| 18 | ) |
| 19 | device = "cuda" |
| 20 | |
| 21 | model_config = HunyuanDiT2DControlNetModel.load_config( |
| 22 | "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer" |
| 23 | ) |
| 24 | model_config["use_style_cond_and_image_meta_size"] = ( |
| 25 | args.use_style_cond_and_image_meta_size |
| 26 | ) ### version <= v1.1: True; version >= v1.2: False |
| 27 | print(model_config) |
| 28 | |
| 29 | for key in state_dict: |
| 30 | print("local:", key) |
| 31 | |
| 32 | model = HunyuanDiT2DControlNetModel.from_config(model_config).to(device) |
| 33 | |
| 34 | for key in model.state_dict(): |
| 35 | print("diffusers:", key) |
| 36 | |
| 37 | num_layers = 19 |
| 38 | for i in range(num_layers): |
| 39 | # attn1 |
| 40 | # Wkqv -> to_q, to_k, to_v |
| 41 | q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0) |
| 42 | q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0) |
| 43 | state_dict[f"blocks.{i}.attn1.to_q.weight"] = q |
| 44 | state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias |
| 45 | state_dict[f"blocks.{i}.attn1.to_k.weight"] = k |
| 46 | state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias |
| 47 | state_dict[f"blocks.{i}.attn1.to_v.weight"] = v |
| 48 | state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias |
| 49 | state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight") |
| 50 | state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias") |
| 51 | |
| 52 | # q_norm, k_norm -> norm_q, norm_k |
| 53 | state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"] |
| 54 | state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"] |
| 55 | state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"] |
| 56 | state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"] |
| 57 | |
| 58 | state_dict.pop(f"blocks.{i}.attn1.q_norm.weight") |
| 59 | state_dict.pop(f"blocks.{i}.attn1.q_norm.bias") |
| 60 | state_dict.pop(f"blocks.{i}.attn1.k_norm.weight") |
| 61 | state_dict.pop(f"blocks.{i}.attn1.k_norm.bias") |
| 62 | |
| 63 | # out_proj -> to_out |
| 64 | state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"] |
| 65 | state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"] |
no test coverage detected
searching dependent graphs…