MCP · copy
hub / github.com/vladmandic/sdnext / full_vae_decode

Function full_vae_decode

modules/processing_vae.py:89–181  ·  view source on GitHub ↗
(latents, model)

Source from the content-addressed store, hash-verified

87
88
89def full_vae_decode(latents, model):
90 t0 = time.time()
91 if not hasattr(model, 'vae') and hasattr(model, 'pipe'):
92 model = model.pipe
93 if model is None or not hasattr(model, 'vae'):
94 log.error('VAE not found in model')
95 return []
96 if debug:
97 devices.torch_gc(force=True)
98 shared.mem_mon.reset()
99
100 base_device = None
101 if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False):
102 base_device = sd_models.move_base(model, devices.cpu)
103 elif shared.opts.diffusers_offload_mode != "sequential":
104 sd_models.move_model(model.vae, devices.device)
105
106 sd_models.set_vae_options(model, vae=None, op='decode')
107 upcast = (model.vae.dtype == torch.float16) and (getattr(model.vae.config, 'force_upcast', False) or shared.opts.no_half_vae)
108 if upcast:
109 if hasattr(model, 'upcast_vae'): # this is done by diffusers automatically if output_type != 'latent'
110 model.upcast_vae()
111 else: # manual upcast and we restore it later
112 model.vae.orig_dtype = model.vae.dtype
113 model.vae = model.vae.to(dtype=torch.float32)
114 latents = latents.to(devices.device)
115
116 # normalize latents
117 latents_mean = model.vae.config.get("latents_mean", None)
118 latents_std = model.vae.config.get("latents_std", None)
119 scaling_factor = model.vae.config.get("scaling_factor", 1.0)
120 shift_factor = model.vae.config.get("shift_factor", None)
121 if latents_mean and latents_std:
122 broadcast_shape = [1 for _ in range(latents.ndim)]
123 broadcast_shape[1] = -1
124 latents_mean = (torch.tensor(latents_mean).view(*broadcast_shape).to(latents.device, latents.dtype))
125 latents_std = (torch.tensor(latents_std).view(*broadcast_shape).to(latents.device, latents.dtype))
126 latents = ((latents * latents_std) / scaling_factor) + latents_mean
127 else:
128 latents = latents / scaling_factor
129 if shift_factor:
130 latents = latents + shift_factor
131
132 # check dims
133 if model.vae.__class__.__name__ in ['AutoencoderKLWan'] and latents.ndim == 4:
134 latents = latents.unsqueeze(2) # wan is __nhw
135
136 # handle quants
137 if getattr(model.vae, "post_quant_conv", None) is not None:
138 if getattr(model.vae.post_quant_conv, "bias", None) is not None:
139 latents = latents.to(model.vae.post_quant_conv.bias.dtype)
140 elif "VAE" in shared.opts.sdnq_quantize_weights:
141 latents = latents.to(devices.dtype_vae)
142 else:
143 latents = latents.to(next(iter(model.vae.post_quant_conv.parameters())).dtype)
144 # if getattr(model.vae.post_quant_conv, "bias", None) is not None:
145 # model.vae.post_quant_conv.bias = torch.nn.Parameter(model.vae.post_quant_conv.bias.to(devices.device), requires_grad=False)
146 # if getattr(model.vae.post_quant_conv, "weight", None) is not None:

Callers 1

vae_decode · Function · 0.85

Calls 9

view · Method · 0.80
reset · Method · 0.45
upcast_vae · Method · 0.45
to · Method · 0.45
get · Method · 0.45
decode · Method · 0.45
display · Method · 0.45
apply · Method · 0.45
read · Method · 0.45

Tested by

no test coverage detected