MCPcopy Index your code
hub / github.com/huggingface/diffusers / main

Function main

scripts/convert_sana_video_to_diffusers.py:35–295  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

33
34
35def main(args):
36 cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
37
38 if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
39 ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
40 snapshot_download(
41 repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
42 cache_dir=cache_dir_path,
43 repo_type="model",
44 )
45 file_path = hf_hub_download(
46 repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
47 filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
48 cache_dir=cache_dir_path,
49 repo_type="model",
50 )
51 else:
52 file_path = args.orig_ckpt_path
53
54 print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
55 all_state_dict = torch.load(file_path, weights_only=True)
56 state_dict = all_state_dict.pop("state_dict")
57 converted_state_dict = {}
58
59 # Patch embeddings.
60 converted_state_dict["patch_embedding.weight"] = state_dict.pop("x_embedder.proj.weight")
61 converted_state_dict["patch_embedding.bias"] = state_dict.pop("x_embedder.proj.bias")
62
63 # Caption projection.
64 converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
65 converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
66 converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
67 converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
68
69 converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
70 "t_embedder.mlp.0.weight"
71 )
72 converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
73 converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
74 "t_embedder.mlp.2.weight"
75 )
76 converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
77
78 # Shared norm.
79 converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
80 converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
81
82 # y norm
83 converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
84
85 # scheduler
86 flow_shift = 8.0
87 if args.task == "i2v":
88 assert args.scheduler_type == "flow-euler", "Scheduler type must be flow-euler for i2v task."
89
90 # model config
91 layer_num = 20
92 # Positional embedding interpolation scale.

Calls 13

SanaVideoPipeline (class) · 0.90
split (method) · 0.80
parameters (method) · 0.80
load (method) · 0.45
pop (method) · 0.45
load_state_dict (method) · 0.45
to (method) · 0.45
save_pretrained (method) · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…