(self)
# Stub definition: the body is only `...`, so nothing is loaded here.
# NOTE(review): presumably the concrete cleaner-loading logic is supplied
# elsewhere (e.g. by a subclass or a matching implementation) -- confirm.
def load_cleaner(self) -> None: ...
| 47 | |
def load_cldm(self) -> None:
    """Build ``self.cldm`` and ``self.diffusion`` and load their weights.

    Steps, in order:

    1. Instantiate ``self.cldm`` from ``configs/inference/cldm.yaml``.
    2. Load the pre-trained Stable Diffusion checkpoint selected by
       ``self.args.version`` and report unused/missing weights.
    3. Load the ControlNet checkpoint selected by ``self.args.version``
       (and, for ``"v1"``, by ``self.args.task``).
    4. Switch the model to eval mode, move it to ``self.args.device`` and
       cast it to the dtype named by ``self.args.precision``.
    5. Instantiate ``self.diffusion`` from the version-specific config and
       move it to ``self.args.device``.

    Raises:
        ValueError: If ``self.args.version`` is ``"v1"`` and
            ``self.args.task`` is not ``"face"``, ``"sr"`` or ``"denoise"``.
        KeyError: If ``self.args.precision`` is not ``"fp32"``, ``"fp16"``
            or ``"bf16"``.
    """
    self.cldm: ControlLDM = instantiate_from_config(
        OmegaConf.load("configs/inference/cldm.yaml")
    )
    version = self.args.version

    # NOTE(review): the three version checks below do not fall through the
    # same way. A version other than "v1"/"v2"/"v2.1" would get the
    # non-zsnr SD checkpoint together with the v2.1 ControlNet and the v2.1
    # diffusion config. This relies on the caller restricting the value.

    # Pre-trained SD weights: only v2.1 uses the "zsnr" checkpoint; v1 and
    # v2 share the plain SD v2.1 checkpoint.
    sd_key = "sd_v2.1_zsnr" if version == "v2.1" else "sd_v2.1"
    sd_weight = load_model_from_url(MODELS[sd_key])
    unused, missing = self.cldm.load_pretrained_sd(sd_weight)
    print(
        f"load pretrained stable diffusion, "
        f"unused weights: {unused}, missing weights: {missing}"
    )

    # ControlNet weights: v1 ships task-specific checkpoints, later versions
    # use a single checkpoint each.
    if version == "v1":
        task = self.args.task
        if task == "face":
            control_key = "v1_face"
        elif task in ("sr", "denoise"):
            control_key = "v1_general"
        else:
            raise ValueError(
                f"DiffBIR v1 doesn't support task: {task}, "
                f"please use v2 or v2.1 by passing '--version'"
            )
    elif version == "v2":
        control_key = "v2"
    else:
        # v2.1
        control_key = "v2.1"
    control_weight = load_model_from_url(MODELS[control_key])
    self.cldm.load_controlnet_from_ckpt(control_weight)
    print("load controlnet weight")

    # Inference only: eval mode, target device, then the requested precision.
    self.cldm.eval().to(self.args.device)
    cast_type = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }[self.args.precision]
    self.cldm.cast_dtype(cast_type)

    # Diffusion: v1 and v2 share one config, everything else uses v2.1's.
    if version in ("v1", "v2"):
        config = "configs/inference/diffusion.yaml"
    else:
        config = "configs/inference/diffusion_v2.1.yaml"
    self.diffusion: Diffusion = instantiate_from_config(OmegaConf.load(config))
    self.diffusion.to(self.args.device)
| 97 | |
| 98 | def load_cond_fn(self) -> None: |
| 99 | if not self.args.guidance: |
no test coverage detected