(checkpoint_path, config_path, output_path)
| 7 | |
| 8 | |
def _extract_sub_state_dict(state_dict, prefix):
    """Return the entries of ``state_dict`` whose key starts with ``prefix``, prefix removed."""
    # Slice rather than str.replace: only the leading prefix must be dropped,
    # not every occurrence of the same text further inside the key.
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def convert_ldm_original(checkpoint_path, config_path, output_path):
    """Convert an original LDM checkpoint into an ``LDMPipeline`` saved at ``output_path``.

    Args:
        checkpoint_path: Path of the checkpoint file; the loaded object must
            contain a ``"model"`` entry holding the state dict.
        config_path: Path of the YAML config file. It must provide
            ``model.params`` with ``first_stage_config.params``,
            ``unet_config.params``, ``timesteps``, ``linear_start`` and
            ``linear_end``.
        output_path: Destination passed to ``pipeline.save_pretrained``.
    """
    # yaml.safe_load parses the text it is given, so passing the path string
    # would yield the path itself instead of the config mapping. Open the file
    # and parse its contents.
    with open(config_path, encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]

    # extract state_dict for VQVAE
    first_stage_dict = _extract_sub_state_dict(state_dict, "first_stage_model.")

    # extract state_dict for UNetLDM
    unet_state_dict = _extract_sub_state_dict(state_dict, "model.diffusion_model.")

    model_params = config["model"]["params"]
    vqvae_init_args = model_params["first_stage_config"]["params"]
    unet_init_args = model_params["unet_config"]["params"]

    vqvae = VQModel(**vqvae_init_args).eval()
    vqvae.load_state_dict(first_stage_dict)

    unet = UNetLDMModel(**unet_init_args).eval()
    unet.load_state_dict(unet_state_dict)

    noise_scheduler = DDIMScheduler(
        timesteps=model_params["timesteps"],
        beta_schedule="scaled_linear",
        beta_start=model_params["linear_start"],
        beta_end=model_params["linear_end"],
        clip_sample=False,
    )

    pipeline = LDMPipeline(vqvae, unet, noise_scheduler)
    pipeline.save_pretrained(output_path)
| 47 | |
| 48 | |
| 49 | if __name__ == "__main__": |
no test coverage detected
searching dependent graphs…