(x0, step, total_steps)
@torch.inference_mode()
def preview_function(x0, step, total_steps):
    """Decode a latent batch into a uint8 preview image of its first item.

    ``VAE_approx_model`` is not a parameter. It is resolved from the
    enclosing scope, which is not visible here.

    Args:
        x0: Latent tensor passed to ``VAE_approx_model``. The model output is
            rearranged from ``'b c h w'`` layout, so a 4-D batch result is
            expected (TODO confirm the exact latent shape against callers).
        step: Current sampling step. Unused; kept so the function matches the
            preview-callback signature its caller expects.
        total_steps: Total number of sampling steps. Unused, for the same
            reason.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` with shape ``(h, w, c)``,
        holding the first element of the batch.
    """
    # torch.inference_mode() already disables autograd (it is stricter than
    # torch.no_grad()), so no additional no_grad decorator or context is needed.

    # Cast to the dtype the approximate decoder runs in. `current_type` is
    # presumably its parameter dtype (e.g. fp16/fp32); TODO confirm.
    latent = x0.to(VAE_approx_model.current_type)

    # Affine map 127.5 * y + 127.5. This assumes the model emits values roughly
    # in [-1, 1], which land in [0, 255]; the clip below handles overshoot.
    image = VAE_approx_model(latent) * 127.5 + 127.5

    # Channels-last layout, keeping only the first batch element.
    image = einops.rearrange(image, 'b c h w -> b h w c')[0]

    # Move to host, clamp into the displayable range, and truncate to bytes.
    return image.cpu().numpy().clip(0, 255).astype(np.uint8)
| 259 | |
| 260 | return preview_function |
| 261 |