Repository: github.com/huggingface/diffusers

Function convert_ldm_original

Defined in scripts/conversion_ldm_uncond.py, lines 9–46
Signature: convert_ldm_original(checkpoint_path, config_path, output_path)


```python
import torch
import yaml

from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel


def convert_ldm_original(checkpoint_path, config_path, output_path):
    # yaml.safe_load expects YAML text or a file object, so open the path first;
    # passing the path string directly would just parse the string itself.
    with open(config_path) as f:
        config = yaml.safe_load(f)
    # The original LDM checkpoint stores its weights under the "model" key.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    keys = list(state_dict.keys())

    # extract state_dict for VQVAE
    first_stage_dict = {}
    first_stage_key = "first_stage_model."
    for key in keys:
        if key.startswith(first_stage_key):
            first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]

    # extract state_dict for UNetLDM
    unet_state_dict = {}
    unet_key = "model.diffusion_model."
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = state_dict[key]

    # Constructor arguments for both sub-models come straight from the config.
    vqvae_init_args = config["model"]["params"]["first_stage_config"]["params"]
    unet_init_args = config["model"]["params"]["unet_config"]["params"]

    vqvae = VQModel(**vqvae_init_args).eval()
    vqvae.load_state_dict(first_stage_dict)

    unet = UNetLDMModel(**unet_init_args).eval()
    unet.load_state_dict(unet_state_dict)

    # Reuse the training noise schedule from the original config.
    # (Newer diffusers releases name the first argument num_train_timesteps.)
    noise_scheduler = DDIMScheduler(
        timesteps=config["model"]["params"]["timesteps"],
        beta_schedule="scaled_linear",
        beta_start=config["model"]["params"]["linear_start"],
        beta_end=config["model"]["params"]["linear_end"],
        clip_sample=False,
    )

    # Bundle the three components and write them out in diffusers format.
    pipeline = LDMPipeline(vqvae, unet, noise_scheduler)
    pipeline.save_pretrained(output_path)


# if __name__ == "__main__":  (script entry point; body not shown in this view)
```
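The two extraction loops in convert_ldm_original apply the same rule with different prefixes: keep only the keys that start with the prefix and strip it off. A minimal stdlib-only sketch of that step follows; the helper name `split_by_prefix` and the toy checkpoint (strings standing in for tensors, illustrative key names) are assumptions for illustration, not part of the script.

```python
def split_by_prefix(state_dict, prefix):
    """Return the entries whose key starts with `prefix`, with the prefix removed.

    Hypothetical helper mirroring the loops in convert_ldm_original. It slices
    off the leading prefix instead of calling str.replace, so a prefix that
    happens to reappear later in a key is left intact.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Toy checkpoint: strings stand in for tensors; key names are illustrative.
checkpoint = {
    "first_stage_model.encoder.conv_in.weight": "w0",
    "model.diffusion_model.input_blocks.0.0.weight": "w1",
    "model_ema.decay": "w2",
}

vqvae_sd = split_by_prefix(checkpoint, "first_stage_model.")
unet_sd = split_by_prefix(checkpoint, "model.diffusion_model.")
print(vqvae_sd)  # {'encoder.conv_in.weight': 'w0'}
print(unet_sd)   # {'input_blocks.0.0.weight': 'w1'}
```

Keys under neither prefix (here the made-up `model_ema.decay`) are dropped, just as the original loops skip them.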

Callers: 1

Calls: 6

VQModel · class · 0.90
DDIMScheduler · class · 0.90
LDMPipeline · class · 0.90
load · method · 0.45
load_state_dict · method · 0.45
save_pretrained · method · 0.45
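convert_ldm_original passes the config's linear_start and linear_end to DDIMScheduler as beta_start and beta_end with beta_schedule="scaled_linear". In diffusers, that schedule spaces the square roots of the betas evenly and then squares them, which, to our understanding, matches the "linear" schedule used by the original LDM code. Here is a plain-Python sketch under that assumption; the helper name `scaled_linear_betas` is hypothetical.

```python
def scaled_linear_betas(beta_start, beta_end, num_timesteps):
    """Betas evenly spaced in sqrt(beta), then squared.

    Sketch of the "scaled_linear" schedule, assuming
    betas = linspace(sqrt(beta_start), sqrt(beta_end), N) ** 2.
    """
    if num_timesteps < 2:
        raise ValueError("need at least two timesteps")
    lo, hi = beta_start ** 0.5, beta_end ** 0.5
    step = (hi - lo) / (num_timesteps - 1)
    return [(lo + i * step) ** 2 for i in range(num_timesteps)]


# Small worked example: the midpoint is 4, not 5, because spacing is in sqrt(beta).
print(scaled_linear_betas(1.0, 9.0, 3))  # [1.0, 4.0, 9.0]

# Example values in the style of an LDM config (assumed, not read from a real file).
betas = scaled_linear_betas(0.0015, 0.0195, 1000)
```

Only the endpoints and the step count come from the config; the curvature between them comes entirely from the choice of "scaled_linear".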

Tested by

no test coverage detected
