| 137 | keep_last=True) |
| 138 | |
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
    """Encode frames, and optional reference images, into latents via a VAE.

    Args:
        frames: Sequence of frame tensors; handed to ``vae.encode`` as one
            batch (list in, iterable of latents out).
        ref_images: ``None``, or a sequence with exactly one entry per
            element of ``frames``. Each entry is either ``None`` or a list
            of reference tensors whose latents are prepended to the
            matching frame latent along dim 1.
        masks: Optional sequence with one mask per element of ``frames``.
            Masks are binarized at 0.5. When given, every frame is split
            into a masked-out ("inactive") and masked-in ("reactive") copy,
            both are encoded, and the two latents are concatenated along
            dim 0.
        vae: Object exposing ``encode``; falls back to ``self.vae``.

    Returns:
        A list with one latent tensor per element of ``frames``.

    Raises:
        AssertionError: If ``ref_images`` or ``masks`` does not have the
            same length as ``frames``, or if an encoded reference does not
            have size 1 along dim 1.
    """
    vae = self.vae if vae is None else vae

    if ref_images is None:
        ref_images = [None] * len(frames)
    else:
        assert len(frames) == len(ref_images), (
            f"got {len(frames)} frames but {len(ref_images)} ref_images"
        )

    if masks is None:
        latents = vae.encode(frames)
    else:
        # zip() below would silently drop trailing frames on a mismatch.
        assert len(frames) == len(masks), (
            f"got {len(frames)} frames but {len(masks)} masks"
        )
        masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
        # The "+ 0 * ..." terms are kept on purpose so the result stays
        # bit-identical to the previous implementation (signed zeros).
        inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
        reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
        inactive = vae.encode(inactive)
        reactive = vae.encode(reactive)
        latents = [
            torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)
        ]

    cat_latents = []
    for latent, refs in zip(latents, ref_images):
        if refs is not None:
            ref_latent = vae.encode(refs)
            if masks is not None:
                # Masked latents are twice as large along dim 0
                # (inactive + reactive); pad the references with zeros so
                # they can be concatenated with them along dim 1.
                ref_latent = [
                    torch.cat((u, torch.zeros_like(u)), dim=0)
                    for u in ref_latent
                ]
            # presumably dim 1 is the temporal axis, i.e. each reference is
            # a single latent frame -- TODO confirm against the VAE.
            assert all(x.shape[1] == 1 for x in ref_latent), (
                "each encoded reference must have size 1 along dim 1"
            )
            latent = torch.cat([*ref_latent, latent], dim=1)
        cat_latents.append(latent)
    return cat_latents
| 173 | |
| 174 | def vace_encode_masks(self, masks, ref_images=None, vae_stride=None): |
| 175 | vae_stride = self.vae_stride if vae_stride is None else vae_stride |