(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[])
| 83 | |
| 84 | |
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit] = None, prompt_refiner_classes=None):
    """Attach the sub-models this pipeline uses, looked up from ``model_manager``.

    Each component is retrieved with ``model_manager.fetch_model(<name>)`` and
    stored on ``self``. A component that is not loaded may come back as ``None``
    (the motion-module branch below relies on this).

    Args:
        model_manager: Registry of loaded models to pull components from.
        controlnet_config_units: ControlNet configurations; one ``ControlNetUnit``
            is built per entry from its ``processor_id``, ``model_path`` and
            ``scale``. ``None`` (the default) means no ControlNets.
        prompt_refiner_classes: Forwarded unchanged to
            ``self.prompter.load_prompt_refiners``. ``None`` (the default) is
            treated as an empty list.
    """
    # Build fresh lists per call instead of using ``[]`` defaults, which Python
    # evaluates once and shares across every call of this method.
    if controlnet_config_units is None:
        controlnet_config_units = []
    if prompt_refiner_classes is None:
        prompt_refiner_classes = []

    # Main models
    self.text_encoder = model_manager.fetch_model("sd_text_encoder")
    self.unet = model_manager.fetch_model("sd_unet")
    self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
    self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
    self.prompter.fetch_models(self.text_encoder)
    self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)

    # ControlNets: each unit pairs an annotator (preprocessor) with the
    # ControlNet weights at config.model_path and its conditioning scale.
    controlnet_units = [
        ControlNetUnit(
            Annotator(config.processor_id, device=self.device),
            model_manager.fetch_model("sd_controlnet", config.model_path),
            config.scale,
        )
        for config in controlnet_config_units
    ]
    self.controlnet = MultiControlNetManager(controlnet_units)

    # IP-Adapters
    self.ipadapter = model_manager.fetch_model("sd_ipadapter")
    self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")

    # Motion Modules
    self.motion_modules = model_manager.fetch_model("sd_motion_modules")
    if self.motion_modules is None:
        # NOTE(review): with no motion modules the scheduler is replaced by a
        # "scaled_linear" DDIM scheduler; presumably the constructor may pick a
        # different beta schedule for the motion-module case — confirm in __init__.
        self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
| 114 | |
| 115 | @staticmethod |
no test coverage detected