Unpack packed latents to standard [B, C, H, W] format for correction.
(latents, pipe, p)
| 145 | |
| 146 | |
def _unpack_latents(latents, pipe, p):
    """Unpack packed latents to standard [B, C, H, W] format for correction.

    Args:
        latents: Packed latents. On the Flux 2 path this must be a 3-D tensor
            shaped ``[B, seq_len, patch_channels]``; on the Flux 1 / Bria path
            it is forwarded untouched to the pipeline's own ``_unpack_latents``.
        pipe: Pipeline object. Its attributes select the unpack strategy:
            ``_unpack_latents`` together with ``vae_scale_factor`` selects the
            pipeline's own method, ``_unpatchify_latents`` selects the manual
            reshape. ``vae_scale_factor`` defaults to 8 when absent.
        p: Processing object providing ``width``/``height`` and, optionally,
            the hires fields (``hr_resize_mode``, ``hr_upscaler``,
            ``is_hr_pass``, ``hr_upscale_to_x``, ``hr_upscale_to_y``).

    Returns:
        Tuple ``(latents, pack_type)`` where ``pack_type`` is ``'flux1'``,
        ``'flux2'`` or ``'unknown'``. For ``'unknown'`` the input latents are
        returned unchanged.
    """
    import math  # function-scope import: only needed for the exact integer square root below

    vae_scale = getattr(pipe, 'vae_scale_factor', 8)

    # Hires fields are read defensively, like width/height below, so a processing
    # object that does not define them is simply treated as a non-hires pass.
    hr_resize_mode = getattr(p, 'hr_resize_mode', 0)
    hires_active = (
        hr_resize_mode > 0
        and (getattr(p, 'hr_upscaler', 'None') != 'None' or hr_resize_mode == 5)
        and getattr(p, 'is_hr_pass', False)
    )
    if hires_active:
        # During the hires pass use the larger of base and upscale target size.
        width = max(getattr(p, 'width', 0), getattr(p, 'hr_upscale_to_x', 0))
        height = max(getattr(p, 'height', 0), getattr(p, 'hr_upscale_to_y', 0))
    else:
        width = getattr(p, 'width', 1024)
        height = getattr(p, 'height', 1024)

    if hasattr(pipe, '_unpack_latents') and hasattr(pipe, 'vae_scale_factor'):
        # Flux 1 / Bria: use pipeline's own unpack method
        unpacked = pipe._unpack_latents(latents, height, width, vae_scale)  # pylint: disable=protected-access
        return unpacked, 'flux1'

    if hasattr(pipe, '_unpatchify_latents'):
        # Flux 2: manual reshape [B, seq_len, patch_channels] -> [B, C, H, W]
        b, seq_len, patch_ch = latents.shape
        channels = patch_ch // 4  # each token carries one 2x2 patch per channel
        h_patches = height // vae_scale // 2
        w_patches = width // vae_scale // 2
        if h_patches * w_patches != seq_len:
            # Requested size does not match the token count: fall back to a
            # square patch grid. isqrt is exact, unlike a float square root.
            h_patches = w_patches = math.isqrt(seq_len)
        unpacked = latents.view(b, h_patches, w_patches, channels, 2, 2)
        # (b, h, w, c, ph, pw) -> (b, c, h, ph, w, pw), then merge the patch axes into H and W
        unpacked = unpacked.permute(0, 3, 1, 4, 2, 5).reshape(b, channels, h_patches * 2, w_patches * 2)
        return unpacked, 'flux2'

    return latents, 'unknown'
| 172 | |
| 173 | |
| 174 | def _repack_latents(latents, pack_type, pipe, p): |
no test coverage detected