r""" Forward pass through the diffusion model. Args: x (List[Tensor]): List of input video tensors, each with shape [C_in, F, H, W]; t (Tensor): Diffusion timesteps tensor of shape [B]; context (List[Tensor]): List of text embeddings, each with shape [L, C]. """
(
self,
x,
t,
vace_context,
context,
seq_len,
vace_context_scale=1.0,
clip_fea=None,
y=None,
)
| 153 | return hints |
| 154 | |
| 155 | def forward( |
| 156 | self, |
| 157 | x, |
| 158 | t, |
| 159 | vace_context, |
| 160 | context, |
| 161 | seq_len, |
| 162 | vace_context_scale=1.0, |
| 163 | clip_fea=None, |
| 164 | y=None, |
| 165 | ): |
| 166 | r""" |
| 167 | Forward pass through the diffusion model |
| 168 | |
| 169 | Args: |
| 170 | x (List[Tensor]): |
| 171 | List of input video tensors, each with shape [C_in, F, H, W] |
| 172 | t (Tensor): |
| 173 | Diffusion timesteps tensor of shape [B] |
| 174 | context (List[Tensor]): |
| 175 | List of text embeddings each with shape [L, C] |
| 176 | seq_len (`int`): |
| 177 | Maximum sequence length for positional encoding |
| 178 | clip_fea (Tensor, *optional*): |
| 179 | CLIP image features for image-to-video mode |
| 180 | y (List[Tensor], *optional*): |
| 181 | Conditional video inputs for image-to-video mode, same shape as x |
| 182 | |
| 183 | Returns: |
| 184 | List[Tensor]: |
| 185 | List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] |
| 186 | """ |
| 187 | # if self.model_type == 'i2v': |
| 188 | # assert clip_fea is not None and y is not None |
| 189 | # params |
| 190 | device = self.patch_embedding.weight.device |
| 191 | if self.freqs.device != device: |
| 192 | self.freqs = self.freqs.to(device) |
| 193 | |
| 194 | # if y is not None: |
| 195 | # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] |
| 196 | |
| 197 | # embeddings |
| 198 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] |
| 199 | grid_sizes = torch.stack( |
| 200 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) |
| 201 | x = [u.flatten(2).transpose(1, 2) for u in x] |
| 202 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) |
| 203 | assert seq_lens.max() <= seq_len |
| 204 | x = torch.cat([ |
| 205 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], |
| 206 | dim=1) for u in x |
| 207 | ]) |
| 208 | |
| 209 | # time embeddings |
| 210 | with amp.autocast(dtype=torch.float32): |
| 211 | e = self.time_embedding( |
| 212 | sinusoidal_embedding_1d(self.freq_dim, t).float()) |
no test coverage detected