(self, sd_unet_option)
| 265 | self.apply_unet(sd_unet_option) |
| 266 | |
def apply_unet(self, sd_unet_option):
    """Make ``sd_unet_option`` the active UNet, or fall back to PyTorch.

    Returns immediately when ``sd_unet_option`` equals the recorded
    ``sd_unet.current_unet_option``, a UNet is currently registered, and the
    PyTorch fallback (``self.torch_unet``) is not requested.

    Otherwise any registered UNet is deactivated and unregistered, then:

    * if ``self.torch_unet`` is truthy: a Gradio warning is shown, no UNet is
      registered, the option is recorded, and the stock diffusion model is
      moved to ``devices.device``;
    * else: the stock diffusion model is moved to ``devices.cpu`` and
      ``devices.torch_gc()`` is called, ``self.update_lora`` is set when
      ``self.lora_refit_dict`` is non-empty, a new UNet is created from
      ``sd_unet_option``, tagged with ``profile_idx = self.idx`` and
      ``option = sd_unet_option``, registered as ``sd_unet.current_unet``,
      and activated.

    Args:
        sd_unet_option: The UNet option to apply. It must provide
            ``create_unet()`` and a ``label`` attribute (used in the log line).
    """
    already_applied = (
        sd_unet_option == sd_unet.current_unet_option
        and sd_unet.current_unet is not None
        and not self.torch_unet
    )
    if already_applied:
        return

    if sd_unet.current_unet is not None:
        sd_unet.current_unet.deactivate()
        # Unregister the deactivated UNet right away. If create_unet() below
        # raises, leaving it in place would make the early-return guard above
        # treat it as still applied on the next call with the previous option,
        # so the UNet would never be re-created.
        sd_unet.current_unet = None

    if self.torch_unet:
        # NOTE(review): presumably set when no matching engine exists, as the
        # warning text suggests; confirm where torch_unet is assigned.
        gr.Warning("Enabling PyTorch fallback as no engine was found.")
        sd_unet.current_unet = None
        sd_unet.current_unet_option = sd_unet_option
        shared.sd_model.model.diffusion_model.to(devices.device)
        return

    # Offload the stock diffusion model and collect memory before the
    # replacement UNet is created.
    shared.sd_model.model.diffusion_model.to(devices.cpu)
    devices.torch_gc()

    if self.lora_refit_dict:
        # NOTE(review): presumably asks a later step to refit pending LoRA
        # weights into the newly created UNet; confirm where update_lora is read.
        self.update_lora = True

    # Configure the new UNet fully before publishing it globally.
    new_unet = sd_unet_option.create_unet()
    new_unet.profile_idx = self.idx
    new_unet.option = sd_unet_option
    sd_unet.current_unet = new_unet
    sd_unet.current_unet_option = sd_unet_option

    print(f"Activating unet: {sd_unet.current_unet.option.label}")
    sd_unet.current_unet.activate()
| 297 | def process_batch(self, p, *args, **kwargs): |
| 298 | # Called for each batch count |
no test coverage detected