(
self,
vae,
latents,
horizontal_tiles,
vertical_tiles,
overlap,
last_frame_fix,
working_device="auto",
working_dtype="auto",
)
| 35 | CATEGORY = "latent" |
| 36 | |
| 37 | def decode( |
| 38 | self, |
| 39 | vae, |
| 40 | latents, |
| 41 | horizontal_tiles, |
| 42 | vertical_tiles, |
| 43 | overlap, |
| 44 | last_frame_fix, |
| 45 | working_device="auto", |
| 46 | working_dtype="auto", |
| 47 | ): |
| 48 | # Get the latent samples |
| 49 | samples = latents["samples"] |
| 50 | |
| 51 | if last_frame_fix: |
| 52 | # Repeat the last frame along dimension 2 (frames) |
| 53 | # samples: [batch, channels, frames, height, width] |
| 54 | last_frame = samples[ |
| 55 | :, :, -1:, :, : |
| 56 | ] # shape: [batch, channels, 1, height, width] |
| 57 | samples = torch.cat([samples, last_frame], dim=2) |
| 58 | |
| 59 | batch, channels, frames, height, width = samples.shape |
| 60 | time_scale_factor, width_scale_factor, height_scale_factor = ( |
| 61 | vae.downscale_index_formula |
| 62 | ) |
| 63 | image_frames = 1 + (frames - 1) * time_scale_factor |
| 64 | |
| 65 | # Calculate output image dimensions |
| 66 | output_height = height * height_scale_factor |
| 67 | output_width = width * width_scale_factor |
| 68 | |
| 69 | # Calculate tile sizes with overlap |
| 70 | base_tile_height = (height + (vertical_tiles - 1) * overlap) // vertical_tiles |
| 71 | base_tile_width = (width + (horizontal_tiles - 1) * overlap) // horizontal_tiles |
| 72 | |
| 73 | # Initialize output tensor and weight tensor |
| 74 | # VAE decode returns images in format [batch, height, width, channels] |
| 75 | output = None |
| 76 | weights = None |
| 77 | |
| 78 | target_device = samples.device if working_device == "auto" else working_device |
| 79 | if working_dtype == "auto": |
| 80 | target_dtype = samples.dtype |
| 81 | elif working_dtype == "float16": |
| 82 | target_dtype = torch.float16 |
| 83 | elif working_dtype == "float32": |
| 84 | target_dtype = torch.float32 |
| 85 | |
| 86 | output = torch.zeros( |
| 87 | ( |
| 88 | batch, |
| 89 | image_frames, |
| 90 | output_height, |
| 91 | output_width, |
| 92 | 3, |
| 93 | ), |
| 94 | device=target_device, |
no outgoing calls
no test coverage detected