| 1819 | return (vace_input,) |
| 1820 | |
def vace_encode_frames(self, vae, frames, ref_images, masks=None, tiled_vae=False):
    """Encode frames with the VAE and prepend any reference-image latents.

    Args:
        vae: Object exposing ``encode(items, device=..., tiled=...)``. Its
            result is iterated item by item here.
        frames: Sized sequence of frame tensors to encode.
        ref_images: ``None``, or a sequence of the same length as ``frames``
            whose entries are either ``None`` or the reference images for the
            matching frame entry.
        masks: Optional sequence paired element-wise with ``frames``. When
            given, every frame is split into a part weighted by ``1 - mask``
            and a part weighted by ``mask``; both parts are encoded and the
            two latents are concatenated along dim 0.
        tiled_vae: Passed through to ``vae.encode`` as ``tiled``.

    Returns:
        A list with one latent per (latent, ref_images) pair. Entries that
        have reference images carry the reference latents concatenated in
        front of the frame latent along dim 1.

    Raises:
        AssertionError: If ``ref_images`` is given with a length different
            from ``frames``, or if a reference latent does not have size 1
            along dim 1.

    Note:
        ``device`` and ``ProgressBar`` are resolved from the enclosing module
        scope; they are not defined in this method.
    """
    if ref_images is not None:
        assert len(frames) == len(ref_images)
    else:
        ref_images = [None] * len(frames)

    pbar = ProgressBar(len(frames))
    use_masks = masks is not None

    if use_masks:
        # Part of each frame outside the mask, and part inside the mask.
        kept = [frame * (1 - mask) + 0 * mask for frame, mask in zip(frames, masks)]
        edited = [frame * mask + 0 * (1 - mask) for frame, mask in zip(frames, masks)]
        # Drop the local reference to the inputs before running the encoder.
        del frames
        kept = vae.encode(kept, device=device, tiled=tiled_vae)
        edited = vae.encode(edited, device=device, tiled=tiled_vae)
        latents = [torch.cat(pair, dim=0) for pair in zip(kept, edited)]
        del kept, edited
    else:
        latents = vae.encode(frames, device=device, tiled=tiled_vae)

    cat_latents = []
    for frame_latent, refs in zip(latents, ref_images):
        if refs is not None:
            encoded_refs = vae.encode(refs, device=device, tiled=tiled_vae)
            if use_masks:
                # Masked frame latents are twice as large along dim 0, so pad
                # each reference latent with zeros to the same size.
                encoded_refs = [
                    torch.cat((ref, torch.zeros_like(ref)), dim=0)
                    for ref in encoded_refs
                ]
            # Each reference must contribute exactly one slice along dim 1
            # (presumably the temporal axis -- TODO confirm).
            assert all(ref.shape[1] == 1 for ref in encoded_refs)
            frame_latent = torch.cat([*encoded_refs, frame_latent], dim=1)
        cat_latents.append(frame_latent)
        pbar.update(1)
    return cat_latents
| 1853 | |
| 1854 | def vace_encode_masks(self, masks, ref_images=None): |
| 1855 | if ref_images is None: |