| 1852 | return cat_latents |
| 1853 | |
def vace_encode_masks(self, masks, ref_images=None):
    """Fold pixel-space masks into the latent grid and pad for reference frames.

    For each mask, only the first channel is used. Each
    ``VAE_STRIDE[1] x VAE_STRIDE[2]`` spatial patch is moved into the
    channel dimension, the depth axis is resampled to
    ``(depth + 3) // VAE_STRIDE[0]`` with nearest-exact interpolation, and
    ``len(refs)`` all-zero frames are prepended along the depth axis when
    reference images are supplied for that mask.

    Args:
        masks: Sequence of 4-D tensors unpacked as
            ``(channels, depth, height, width)``. Height and width must be
            multiples of twice the matching spatial stride, otherwise the
            patch reshape raises.
        ref_images: Optional sequence parallel to ``masks``. Each entry is
            ``None`` or a sized collection; only its length is used here.

    Returns:
        A list with one tensor per mask, shaped
        ``(VAE_STRIDE[1] * VAE_STRIDE[2], new_depth [+ len(refs)],
        height_latent, width_latent)``.

    Raises:
        AssertionError: If ``ref_images`` is given and its length differs
            from ``len(masks)``.
    """
    if ref_images is None:
        ref_images = [None] * len(masks)
    else:
        assert len(masks) == len(ref_images)

    stride_t = VAE_STRIDE[0]
    stride_h = VAE_STRIDE[1]
    stride_w = VAE_STRIDE[2]

    result_masks = []
    pbar = ProgressBar(len(masks))
    for mask, refs in zip(masks, ref_images):
        _c, depth, height, width = mask.shape
        new_depth = int((depth + 3) // stride_t)
        # Latent-grid size, rounded down to an even number of cells.
        height = 2 * (int(height) // (stride_h * 2))
        width = 2 * (int(width) // (stride_w * 2))

        # Keep only the first channel, then split every spatial axis into
        # (latent cell, offset inside the cell). reshape (not view) so a
        # non-contiguous input is copied instead of raising.
        mask = mask[0, :, :, :]
        mask = mask.reshape(
            depth, height, stride_h, width, stride_w
        )  # depth, height, stride_h, width, stride_w
        mask = mask.permute(2, 4, 0, 1, 3)  # stride_h, stride_w, depth, height, width
        mask = mask.reshape(
            stride_h * stride_w, depth, height, width
        )  # stride_h*stride_w, depth, height, width

        # Resample only the depth axis; height/width already match.
        mask = F.interpolate(
            mask.unsqueeze(0),
            size=(new_depth, height, width),
            mode='nearest-exact',
        ).squeeze(0)

        if refs is not None:
            # Allocate exactly len(refs) zero frames. Slicing the mask to
            # size the padding would silently truncate when there are more
            # references than mask frames.
            mask_pad = mask.new_zeros(
                (mask.shape[0], len(refs), mask.shape[2], mask.shape[3])
            )
            mask = torch.cat((mask_pad, mask), dim=1)
        result_masks.append(mask)
        pbar.update(1)
    return result_masks
| 1888 | |
def vace_latent(self, z, m):
    """Pair each latent with its mask and join them along dimension 0.

    Args:
        z: Iterable of latent tensors.
        m: Iterable of mask tensors, consumed in lockstep with ``z``.

    Returns:
        A list holding ``torch.cat([latent, mask], dim=0)`` for every pair;
        iteration stops at the shorter of the two inputs.
    """
    combined = []
    for latent, mask in zip(z, m):
        stacked = torch.cat([latent, mask], dim=0)
        combined.append(stacked)
    return combined