(self)
| 18 | class BIDInferenceLoop(InferenceLoop): |
| 19 | |
def load_cleaner(self) -> None:
    """Instantiate the restoration ("cleaner") network and load its weights.

    The model config and checkpoint key are selected from ``self.args.version``:

    * ``"v1"``         -> ``configs/inference/swinir.yaml`` + ``MODELS["swinir_general"]``
    * ``"v2"``         -> ``configs/inference/scunet.yaml`` + ``MODELS["scunet_psnr"]``
    * any other value  -> ``configs/inference/swinir.yaml`` + ``MODELS["swinir_realesrgan"]``

    The resulting network is stored on ``self.cleaner``. Its state dict is
    loaded strictly, and the module is put in eval mode and moved to
    ``self.args.device``.
    """
    swinir_config = "configs/inference/swinir.yaml"
    # version -> (config path, MODELS key); unknown versions use the fallback below.
    by_version = {
        "v1": (swinir_config, "swinir_general"),
        "v2": ("configs/inference/scunet.yaml", "scunet_psnr"),
    }
    config_path, model_key = by_version.get(
        self.args.version, (swinir_config, "swinir_realesrgan")
    )
    # Resolve the checkpoint entry before building the model, as the original did.
    checkpoint = MODELS[model_key]

    self.cleaner: SCUNet | SwinIR = instantiate_from_config(OmegaConf.load(config_path))
    state_dict = load_model_from_url(checkpoint)
    # strict=True: checkpoint keys must exactly match the instantiated architecture.
    self.cleaner.load_state_dict(state_dict, strict=True)
    self.cleaner.eval().to(self.args.device)
| 34 | |
| 35 | def load_pipeline(self) -> None: |
| 36 | if self.args.version == "v1" or self.args.version == "v2.1": |
nothing calls this directly
no test coverage detected