(self, vae, width, height, num_frames, strength, vace_start_percent, vace_end_percent, input_frames=None, ref_images=None, input_masks=None, prev_vace_embeds=None, tiled_vae=False)
| 1741 | CATEGORY = "WanVideoWrapper" |
| 1742 | |
| 1743 | def process(self, vae, width, height, num_frames, strength, vace_start_percent, vace_end_percent, input_frames=None, ref_images=None, input_masks=None, prev_vace_embeds=None, tiled_vae=False): |
| 1744 | width = (width // 16) * 16 |
| 1745 | height = (height // 16) * 16 |
| 1746 | |
| 1747 | target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1, |
| 1748 | height // VAE_STRIDE[1], |
| 1749 | width // VAE_STRIDE[2]) |
| 1750 | # vace context encode |
| 1751 | if input_frames is None: |
| 1752 | input_frames = torch.zeros((1, 3, num_frames, height, width), device=device, dtype=vae.dtype) |
| 1753 | else: |
| 1754 | input_frames = input_frames.clone()[:num_frames, :, :, :3] |
| 1755 | input_frames = common_upscale(input_frames.movedim(-1, 1), width, height, "lanczos", "disabled").movedim(1, -1) |
| 1756 | input_frames = input_frames.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W |
| 1757 | input_frames = input_frames * 2 - 1 |
| 1758 | if input_masks is None: |
| 1759 | input_masks = torch.ones_like(input_frames, device=device) |
| 1760 | else: |
| 1761 | log.info(f"input_masks shape: {input_masks.shape}") |
| 1762 | input_masks = input_masks[:num_frames] |
| 1763 | input_masks = common_upscale(input_masks.clone().unsqueeze(1), width, height, "nearest-exact", "disabled").squeeze(1) |
| 1764 | input_masks = input_masks.to(vae.dtype).to(device) |
| 1765 | input_masks = input_masks.unsqueeze(-1).unsqueeze(0).permute(0, 4, 1, 2, 3).repeat(1, 3, 1, 1, 1) # B, C, T, H, W |
| 1766 | |
| 1767 | if ref_images is not None: |
| 1768 | ref_images = ref_images.clone()[..., :3] |
| 1769 | # Create padded image |
| 1770 | if ref_images.shape[0] > 1: |
| 1771 | ref_images = torch.cat([ref_images[i] for i in range(ref_images.shape[0])], dim=1).unsqueeze(0) |
| 1772 | |
| 1773 | B, H, W, C = ref_images.shape |
| 1774 | current_aspect = W / H |
| 1775 | target_aspect = width / height |
| 1776 | if current_aspect > target_aspect: |
| 1777 | # Image is wider than target, pad height |
| 1778 | new_h = int(W / target_aspect) |
| 1779 | pad_h = (new_h - H) // 2 |
| 1780 | padded = torch.ones(ref_images.shape[0], new_h, W, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype) |
| 1781 | padded[:, pad_h:pad_h+H, :, :] = ref_images |
| 1782 | ref_images = padded |
| 1783 | elif current_aspect < target_aspect: |
| 1784 | # Image is taller than target, pad width |
| 1785 | new_w = int(H * target_aspect) |
| 1786 | pad_w = (new_w - W) // 2 |
| 1787 | padded = torch.ones(ref_images.shape[0], H, new_w, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype) |
| 1788 | padded[:, :, pad_w:pad_w+W, :] = ref_images |
| 1789 | ref_images = padded |
| 1790 | ref_images = common_upscale(ref_images.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) |
| 1791 | |
| 1792 | ref_images = ref_images.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3).unsqueeze(0) |
| 1793 | ref_images = ref_images * 2 - 1 |
| 1794 | |
| 1795 | vae = vae.to(device) |
| 1796 | z0 = self.vace_encode_frames(vae, input_frames, ref_images, masks=input_masks, tiled_vae=tiled_vae) |
| 1797 | |
| 1798 | m0 = self.vace_encode_masks(input_masks, ref_images) |
| 1799 | z = self.vace_latent(z0, m0) |
| 1800 | vae.to(offload_device) |
nothing calls this directly
no test coverage detected