(self, width, height, num_frames, force_offload, noise_aug_strength,
start_latent_strength, end_latent_strength, start_image=None, end_image=None, control_embeds=None, fun_or_fl2v_model=False,
temporal_mask=None, extra_latents=None, clip_embeds=None, tiled_vae=False, add_cond_latents=None, vae=None, augment_empty_frames=0.0, empty_frame_pad_image=None)
| 1016 | CATEGORY = "WanVideoWrapper" |
| 1017 | |
| 1018 | def process(self, width, height, num_frames, force_offload, noise_aug_strength, |
| 1019 | start_latent_strength, end_latent_strength, start_image=None, end_image=None, control_embeds=None, fun_or_fl2v_model=False, |
| 1020 | temporal_mask=None, extra_latents=None, clip_embeds=None, tiled_vae=False, add_cond_latents=None, vae=None, augment_empty_frames=0.0, empty_frame_pad_image=None): |
| 1021 | |
| 1022 | if vae is None: |
| 1023 | raise ValueError("VAE is required for image encoding.") |
| 1024 | H = height |
| 1025 | W = width |
| 1026 | |
| 1027 | lat_h = H // vae.upsampling_factor |
| 1028 | lat_w = W // vae.upsampling_factor |
| 1029 | |
| 1030 | num_frames = ((num_frames - 1) // 4) * 4 + 1 |
| 1031 | two_ref_images = start_image is not None and end_image is not None |
| 1032 | |
| 1033 | if start_image is None and end_image is not None: |
| 1034 | fun_or_fl2v_model = True # end image alone only works with this option |
| 1035 | |
| 1036 | base_frames = num_frames + (1 if two_ref_images and not fun_or_fl2v_model else 0) |
| 1037 | if temporal_mask is None: |
| 1038 | mask = torch.zeros(1, base_frames, lat_h, lat_w, device=device, dtype=vae.dtype) |
| 1039 | if start_image is not None: |
| 1040 | mask[:, 0:start_image.shape[0]] = 1 # First frame |
| 1041 | if end_image is not None: |
| 1042 | mask[:, -end_image.shape[0]:] = 1 # End frame if exists |
| 1043 | else: |
| 1044 | mask = common_upscale(temporal_mask.unsqueeze(1).to(device), lat_w, lat_h, "nearest", "disabled").squeeze(1) |
| 1045 | if mask.shape[0] > base_frames: |
| 1046 | mask = mask[:base_frames] |
| 1047 | elif mask.shape[0] < base_frames: |
| 1048 | mask = torch.cat([mask, torch.zeros(base_frames - mask.shape[0], lat_h, lat_w, device=device)]) |
| 1049 | mask = mask.unsqueeze(0).to(device, vae.dtype) |
| 1050 | |
| 1051 | pixel_mask = mask.clone() |
| 1052 | |
| 1053 | # Repeat first frame and optionally end frame |
| 1054 | start_mask_repeated = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1) # T, C, H, W |
| 1055 | if end_image is not None and not fun_or_fl2v_model: |
| 1056 | end_mask_repeated = torch.repeat_interleave(mask[:, -1:], repeats=4, dim=1) # T, C, H, W |
| 1057 | mask = torch.cat([start_mask_repeated, mask[:, 1:-1], end_mask_repeated], dim=1) |
| 1058 | else: |
| 1059 | mask = torch.cat([start_mask_repeated, mask[:, 1:]], dim=1) |
| 1060 | |
| 1061 | # Reshape mask into groups of 4 frames |
| 1062 | mask = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w) # 1, T, C, H, W |
| 1063 | mask = mask.movedim(1, 2)[0]# C, T, H, W |
| 1064 | |
| 1065 | # Resize and rearrange the input image dimensions |
| 1066 | if start_image is not None: |
| 1067 | start_image = start_image[..., :3] |
| 1068 | if start_image.shape[1] != H or start_image.shape[2] != W: |
| 1069 | resized_start_image = common_upscale(start_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) |
| 1070 | else: |
| 1071 | resized_start_image = start_image.permute(3, 0, 1, 2) # C, T, H, W |
| 1072 | resized_start_image = resized_start_image * 2 - 1 |
| 1073 | if noise_aug_strength > 0.0: |
| 1074 | resized_start_image = add_noise_to_reference_video(resized_start_image, ratio=noise_aug_strength) |
| 1075 |
nothing calls this directly
no test coverage detected