MCPcopy Index your code
hub / github.com/huggingface/diffusers / main

Function main

scripts/convert_sana_to_diffusers.py:47–352  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

45
46
47def main(args):
48 cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
49
50 if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
51 ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
52 snapshot_download(
53 repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
54 cache_dir=cache_dir_path,
55 repo_type="model",
56 )
57 file_path = hf_hub_download(
58 repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
59 filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
60 cache_dir=cache_dir_path,
61 repo_type="model",
62 )
63 else:
64 file_path = args.orig_ckpt_path
65
66 print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
67 all_state_dict = torch.load(file_path, weights_only=True)
68 state_dict = all_state_dict.pop("state_dict")
69 converted_state_dict = {}
70
71 # Patch embeddings.
72 converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
73 converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
74
75 # Caption projection.
76 converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
77 converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
78 converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
79 converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
80
81 # Handle different time embedding structure based on model type
82
83 if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
84 # For Sana Sprint, the time embedding structure is different
85 converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
86 "t_embedder.mlp.0.weight"
87 )
88 converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
89 converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop(
90 "t_embedder.mlp.2.weight"
91 )
92 converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
93
94 # Guidance embedder for Sana Sprint
95 converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop(
96 "cfg_embedder.mlp.0.weight"
97 )
98 converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias")
99 converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop(
100 "cfg_embedder.mlp.2.weight"
101 )
102 converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias")
103 else:
104 # Original Sana time embedding structure

Callers 1

Calls 15

is_accelerate_availableFunction · 0.90
SCMSchedulerClass · 0.90
SanaSprintPipelineClass · 0.90
SanaPipelineClass · 0.90
splitMethod · 0.80
parametersMethod · 0.80
loadMethod · 0.45
popMethod · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…