MCPcopy
hub / github.com/XPixelGroup/DiffBIR / load_cldm

Method load_cldm

diffbir/inference/loop.py:48–96  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

46 def load_cleaner(self) -> None: ...
47
def load_cldm(self) -> None:
    """Build ``self.cldm`` and ``self.diffusion`` for the configured DiffBIR version.

    Reads ``self.args.precision``, ``self.args.version``, ``self.args.device``
    and, for version "v1" only, ``self.args.task``. The ControlLDM is
    instantiated from ``configs/inference/cldm.yaml``. It is loaded with
    pretrained Stable Diffusion weights and then ControlNet weights, switched
    to eval mode, moved to ``self.args.device`` and cast to the requested
    dtype. The diffusion object is built from a version-specific config and
    moved to the same device.

    Raises:
        KeyError: if ``self.args.precision`` is not "fp32", "fp16" or "bf16".
        ValueError: if ``self.args.version`` is not "v1", "v2" or "v2.1", or
            if version "v1" is requested for a task other than "face", "sr"
            or "denoise".
    """
    # Resolve every argument-dependent choice up front so that an invalid
    # precision/version/task fails immediately, before the model is built and
    # before any checkpoint is downloaded.
    cast_type = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }[self.args.precision]

    version = self.args.version
    if version == "v1":
        # Only v1 ships task-specific ControlNet checkpoints.
        task = self.args.task
        if task == "face":
            control_model = "v1_face"
        elif task in ("sr", "denoise"):
            control_model = "v1_general"
        else:
            raise ValueError(
                f"DiffBIR v1 doesn't support task: {task}, "
                f"please use v2 or v2.1 by passing '--version'"
            )
    elif version in ("v2", "v2.1"):
        # The v2 / v2.1 ControlNet entries in MODELS are keyed by the version.
        control_model = version
    else:
        # Reject unknown versions here: the weight and config selections below
        # only distinguish "v2.1" from everything else, so an unknown version
        # would otherwise load an inconsistent mix of assets.
        raise ValueError(
            f"unsupported DiffBIR version: {version!r}, "
            f"expected one of 'v1', 'v2', 'v2.1'"
        )

    self.cldm: ControlLDM = instantiate_from_config(
        OmegaConf.load("configs/inference/cldm.yaml")
    )

    # load pre-trained SD weight; v1 and v2 share the "sd_v2.1" checkpoint
    sd_model = "sd_v2.1_zsnr" if version == "v2.1" else "sd_v2.1"
    sd_weight = load_model_from_url(MODELS[sd_model])
    unused, missing = self.cldm.load_pretrained_sd(sd_weight)
    print(
        "load pretrained stable diffusion, "
        f"unused weights: {unused}, missing weights: {missing}"
    )

    # load controlnet weight
    control_weight = load_model_from_url(MODELS[control_model])
    self.cldm.load_controlnet_from_ckpt(control_weight)
    print("load controlnet weight")
    self.cldm.eval().to(self.args.device)
    self.cldm.cast_dtype(cast_type)

    # load diffusion; v1 and v2 share the same diffusion config
    if version == "v2.1":
        config = "configs/inference/diffusion_v2.1.yaml"
    else:
        config = "configs/inference/diffusion.yaml"
    self.diffusion: Diffusion = instantiate_from_config(OmegaConf.load(config))
    self.diffusion.to(self.args.device)
97
98 def load_cond_fn(self) -> None:
99 if not self.args.guidance:

Callers 1

__init__ · Method · 0.95

Calls 5

instantiate_from_config · Function · 0.85
load_model_from_url · Function · 0.85
load_pretrained_sd · Method · 0.80
cast_dtype · Method · 0.80

Tested by

no test coverage detected