(model)
@torch.no_grad()
@torch.inference_mode()
def get_previewer(model):
    """Build a step-preview callback backed by a small approximate VAE decoder.

    The approximate-VAE weights file is picked from the model's latent format:
    'xlvaeapp.pth' when it is the SDXL latent format, 'vaeapp_sd15.pth'
    otherwise, both resolved inside ``modules.config.path_vae_approx``.
    Each weights file is loaded at most once and cached in the module-level
    ``VAE_approx_models`` dict, keyed by its full path. A freshly loaded model
    is put in eval mode, cast to fp16 when
    ``model_management.should_use_fp16()`` is true (fp32 otherwise), and moved
    to ``model_management.get_torch_device()``.

    Args:
        model: the patched model wrapper; only ``model.model.latent_format``
            is read here.

    Returns:
        ``preview_function(x0, step, total_steps)``, which runs the approximate
        decoder on the latent batch ``x0`` and returns the first item as a
        ``numpy.uint8`` array laid out (height, width, channels).
        ``step`` and ``total_steps`` are accepted but not used.
    """
    global VAE_approx_models

    from modules.config import path_vae_approx

    latent_format = model.model.latent_format
    if isinstance(latent_format, ldm_patched.modules.latent_formats.SDXL):
        weights_name = 'xlvaeapp.pth'
    else:
        weights_name = 'vaeapp_sd15.pth'
    weights_path = os.path.join(path_vae_approx, weights_name)

    if weights_path not in VAE_approx_models:
        # Weights are read onto the CPU; the module is moved to the compute
        # device only after its dtype has been chosen.
        state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
        fresh_model = VAEApprox()
        fresh_model.load_state_dict(state_dict)
        del state_dict  # the module now owns copies of these tensors
        fresh_model.eval()

        model_management = ldm_patched.modules.model_management
        if model_management.should_use_fp16():
            fresh_model.half()
            fresh_model.current_type = torch.float16
        else:
            fresh_model.float()
            fresh_model.current_type = torch.float32

        fresh_model.to(model_management.get_torch_device())
        VAE_approx_models[weights_path] = fresh_model

    approx_model = VAE_approx_models[weights_path]

    @torch.no_grad()
    @torch.inference_mode()
    def preview_function(x0, step, total_steps):
        with torch.no_grad():
            # Cast the latents to whatever dtype the cached decoder was set to.
            latents = x0.to(approx_model.current_type)
            # Affine map to the 0..255 pixel range; presumably the decoder
            # emits values roughly in [-1, 1] (TODO confirm). Anything outside
            # 0..255 is clipped before the uint8 conversion below.
            pixels = approx_model(latents) * 127.5 + 127.5
            first_image = einops.rearrange(pixels, 'b c h w -> b h w c')[0]
            return np.clip(first_image.cpu().numpy(), 0, 255).astype(np.uint8)

    return preview_function
| 261 | |
| 262 | |
| 263 | @torch.no_grad() |
no test coverage detected