(latents, model)
| 87 | |
| 88 | |
| 89 | def full_vae_decode(latents, model): |
| 90 | t0 = time.time() |
| 91 | if not hasattr(model, 'vae') and hasattr(model, 'pipe'): |
| 92 | model = model.pipe |
| 93 | if model is None or not hasattr(model, 'vae'): |
| 94 | log.error('VAE not found in model') |
| 95 | return [] |
| 96 | if debug: |
| 97 | devices.torch_gc(force=True) |
| 98 | shared.mem_mon.reset() |
| 99 | |
| 100 | base_device = None |
| 101 | if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False): |
| 102 | base_device = sd_models.move_base(model, devices.cpu) |
| 103 | elif shared.opts.diffusers_offload_mode != "sequential": |
| 104 | sd_models.move_model(model.vae, devices.device) |
| 105 | |
| 106 | sd_models.set_vae_options(model, vae=None, op='decode') |
| 107 | upcast = (model.vae.dtype == torch.float16) and (getattr(model.vae.config, 'force_upcast', False) or shared.opts.no_half_vae) |
| 108 | if upcast: |
| 109 | if hasattr(model, 'upcast_vae'): # this is done by diffusers automatically if output_type != 'latent' |
| 110 | model.upcast_vae() |
| 111 | else: # manual upcast and we restore it later |
| 112 | model.vae.orig_dtype = model.vae.dtype |
| 113 | model.vae = model.vae.to(dtype=torch.float32) |
| 114 | latents = latents.to(devices.device) |
| 115 | |
| 116 | # normalize latents |
| 117 | latents_mean = model.vae.config.get("latents_mean", None) |
| 118 | latents_std = model.vae.config.get("latents_std", None) |
| 119 | scaling_factor = model.vae.config.get("scaling_factor", 1.0) |
| 120 | shift_factor = model.vae.config.get("shift_factor", None) |
| 121 | if latents_mean and latents_std: |
| 122 | broadcast_shape = [1 for _ in range(latents.ndim)] |
| 123 | broadcast_shape[1] = -1 |
| 124 | latents_mean = (torch.tensor(latents_mean).view(*broadcast_shape).to(latents.device, latents.dtype)) |
| 125 | latents_std = (torch.tensor(latents_std).view(*broadcast_shape).to(latents.device, latents.dtype)) |
| 126 | latents = ((latents * latents_std) / scaling_factor) + latents_mean |
| 127 | else: |
| 128 | latents = latents / scaling_factor |
| 129 | if shift_factor: |
| 130 | latents = latents + shift_factor |
| 131 | |
| 132 | # check dims |
| 133 | if model.vae.__class__.__name__ in ['AutoencoderKLWan'] and latents.ndim == 4: |
| 134 | latents = latents.unsqueeze(2) # wan is __nhw |
| 135 | |
| 136 | # handle quants |
| 137 | if getattr(model.vae, "post_quant_conv", None) is not None: |
| 138 | if getattr(model.vae.post_quant_conv, "bias", None) is not None: |
| 139 | latents = latents.to(model.vae.post_quant_conv.bias.dtype) |
| 140 | elif "VAE" in shared.opts.sdnq_quantize_weights: |
| 141 | latents = latents.to(devices.dtype_vae) |
| 142 | else: |
| 143 | latents = latents.to(next(iter(model.vae.post_quant_conv.parameters())).dtype) |
| 144 | # if getattr(model.vae.post_quant_conv, "bias", None) is not None: |
| 145 | # model.vae.post_quant_conv.bias = torch.nn.Parameter(model.vae.post_quant_conv.bias.to(devices.device), requires_grad=False) |
| 146 | # if getattr(model.vae.post_quant_conv, "weight", None) is not None: |
no test coverage detected