(
cls,
positive,
negative,
vae,
latent,
image,
frame_idx,
strength,
latent_downscale_factor,
crop,
use_tiled_encode,
tile_size,
tile_overlap,
)
| 146 | |
| 147 | @classmethod |
| 148 | def execute( |
| 149 | cls, |
| 150 | positive, |
| 151 | negative, |
| 152 | vae, |
| 153 | latent, |
| 154 | image, |
| 155 | frame_idx, |
| 156 | strength, |
| 157 | latent_downscale_factor, |
| 158 | crop, |
| 159 | use_tiled_encode, |
| 160 | tile_size, |
| 161 | tile_overlap, |
| 162 | ) -> io.NodeOutput: |
| 163 | scale_factors = vae.downscale_index_formula |
| 164 | latent_image = latent["samples"] |
| 165 | noise_mask = nodes_lt.get_noise_mask(latent) |
| 166 | |
| 167 | _, _, latent_length, latent_height, latent_width = latent_image.shape |
| 168 | |
| 169 | time_scale_factor = scale_factors[0] |
| 170 | num_frames_to_keep = ( |
| 171 | (image.shape[0] - 1) // time_scale_factor |
| 172 | ) * time_scale_factor + 1 |
| 173 | causal_fix = frame_idx == 0 or num_frames_to_keep == 1 |
| 174 | if not causal_fix: |
| 175 | image = torch.cat([image[:1], image], dim=0) |
| 176 | |
| 177 | image, guide_latent = cls.encode( |
| 178 | vae, |
| 179 | latent_width, |
| 180 | latent_height, |
| 181 | image, |
| 182 | scale_factors, |
| 183 | latent_downscale_factor, |
| 184 | crop, |
| 185 | use_tiled_encode, |
| 186 | tile_size, |
| 187 | tile_overlap, |
| 188 | ) |
| 189 | |
| 190 | if not causal_fix: |
| 191 | guide_latent = guide_latent[:, :, 1:, :, :] |
| 192 | image = image[1:] |
| 193 | |
| 194 | # Record original (pre-dilation) guide latent shape for spatial mask downsampling |
| 195 | guide_orig_shape = list(guide_latent.shape[2:]) # [F, H_small, W_small] |
| 196 | |
| 197 | guide_mask = None |
| 198 | |
| 199 | # Dilate the latent if latent_downscale_factor > 1 |
| 200 | if latent_downscale_factor > 1: |
| 201 | if ( |
| 202 | latent_width % latent_downscale_factor != 0 |
| 203 | or latent_height % latent_downscale_factor != 0 |
| 204 | ): |
| 205 | raise ValueError( |
Nothing calls this directly.
No test coverage detected.