(image, model)
| 182 | |
| 183 | |
def full_vae_encode(image, model):
    """Encode an image into latents with the full VAE of ``model``.

    Sequence of operations:
      1. If ``shared.opts.diffusers_move_unet`` is set, the model is not marked
         ``has_accelerate`` and it has a ``unet``, move the UNet to CPU.
      2. Unless ``shared.opts.diffusers_offload_mode`` is ``"sequential"``,
         move the VAE to ``devices.device``.
      3. Apply encode-time VAE options via ``sd_models.set_vae_options``.
      4. If the VAE is float16 and either its config requests
         ``force_upcast`` or ``shared.opts.no_half_vae`` is set, upcast it:
         via ``model.upcast_vae()`` when available, otherwise by a manual
         cast to float32 that is undone afterwards.
      5. Sample from the VAE encoder's ``latent_dist``.

    The manual dtype change and the UNet placement from step 1 are restored
    in a ``finally`` block, so they are undone even if encoding raises.
    The ``model.upcast_vae()`` path is not reverted here.

    Args:
        image: Tensor passed to ``model.vae.encode`` after being moved to the
            VAE's device and dtype. The layout is presumably
            (batch, channels, height, width); confirm against callers.
        model: Pipeline-like object exposing ``vae`` and optionally ``unet``,
            ``upcast_vae`` and ``has_accelerate``.

    Returns:
        The latent tensor sampled from ``model.vae.encode(...).latent_dist``.
    """
    t0 = time.time()
    # Evaluate the condition once so the move-back in ``finally`` always
    # mirrors the move-out here, and ``unet_device`` is always bound.
    move_unet = (
        shared.opts.diffusers_move_unet
        and not getattr(model, 'has_accelerate', False)
        and hasattr(model, 'unet')
    )
    unet_device = None
    if move_unet:
        log_debug('Moving to CPU: model=UNet')
        unet_device = model.unet.device
        sd_models.move_model(model.unet, devices.cpu)
    try:
        if shared.opts.diffusers_offload_mode != "sequential" and hasattr(model, 'vae'):
            sd_models.move_model(model.vae, devices.device)
        vae_name = sd_vae.loaded_vae_file if sd_vae.loaded_vae_file is not None else "default"
        log_debug(f'Encode vae="{vae_name}" dtype={model.vae.dtype} upcast={model.vae.config.get("force_upcast", None)}')

        sd_models.set_vae_options(model, vae=None, op='encode')
        upcast = (model.vae.dtype == torch.float16) and (getattr(model.vae.config, 'force_upcast', False) or shared.opts.no_half_vae)
        if upcast:
            if hasattr(model, 'upcast_vae'):  # this is done by diffusers automatically if output_type != 'latent'
                model.upcast_vae()
            else:  # manual upcast; restored in the finally block below
                model.vae.orig_dtype = model.vae.dtype
                model.vae = model.vae.to(dtype=torch.float32)

        encoded = model.vae.encode(image.to(model.vae.device, model.vae.dtype)).latent_dist.sample()
    finally:
        # Undo the manual upcast even on failure, so the VAE is not left in
        # float32 with a stale ``orig_dtype`` marker.
        if hasattr(model, 'vae') and hasattr(model.vae, "orig_dtype"):
            model.vae = model.vae.to(dtype=model.vae.orig_dtype)
            del model.vae.orig_dtype
        # Return the UNet to its original device even if encoding raised.
        if move_unet:
            sd_models.move_model(model.unet, unet_device)
    t1 = time.time()
    log.debug(f'Encode: vae="{vae_name}" upcast={upcast} slicing={getattr(model.vae, "use_slicing", None)} tiling={getattr(model.vae, "use_tiling", None)} latents={encoded.shape}:{encoded.device}:{encoded.dtype} time={t1-t0:.3f}')
    return encoded
| 215 | |
| 216 | |
| 217 | def taesd_vae_decode(latents): |
no test coverage detected