MCPcopy
hub / github.com/MeiGen-AI/InfiniteTalk / forward

Method forward

wan/modules/vace_model.py:155–250  ·  view source on GitHub ↗

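Forward pass through the diffusion model. Args: x (List[Tensor]): list of input video tensors, each with shape [C_in, F, H, W]; t (Tensor): diffusion timesteps tensor of shape [B]; context (List[Tensor]): list of text embeddings, each with shape [L, C]; … (docstring preview truncated — see the full source below)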

(
        self,
        x,
        t,
        vace_context,
        context,
        seq_len,
        vace_context_scale=1.0,
        clip_fea=None,
        y=None,
    )

Source from the content-addressed store, hash-verified

153 return hints
154
155 def forward(
156 self,
157 x,
158 t,
159 vace_context,
160 context,
161 seq_len,
162 vace_context_scale=1.0,
163 clip_fea=None,
164 y=None,
165 ):
166 r"""
167 Forward pass through the diffusion model
168
169 Args:
170 x (List[Tensor]):
171 List of input video tensors, each with shape [C_in, F, H, W]
172 t (Tensor):
173 Diffusion timesteps tensor of shape [B]
174 context (List[Tensor]):
175 List of text embeddings each with shape [L, C]
176 seq_len (`int`):
177 Maximum sequence length for positional encoding
178 clip_fea (Tensor, *optional*):
179 CLIP image features for image-to-video mode
180 y (List[Tensor], *optional*):
181 Conditional video inputs for image-to-video mode, same shape as x
182
183 Returns:
184 List[Tensor]:
185 List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
186 """
187 # if self.model_type == 'i2v':
188 # assert clip_fea is not None and y is not None
189 # params
190 device = self.patch_embedding.weight.device
191 if self.freqs.device != device:
192 self.freqs = self.freqs.to(device)
193
194 # if y is not None:
195 # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
196
197 # embeddings
198 x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
199 grid_sizes = torch.stack(
200 [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
201 x = [u.flatten(2).transpose(1, 2) for u in x]
202 seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
203 assert seq_lens.max() <= seq_len
204 x = torch.cat([
205 torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
206 dim=1) for u in x
207 ])
208
209 # time embeddings
210 with amp.autocast(dtype=torch.float32):
211 e = self.time_embedding(
212 sinusoidal_embedding_1d(self.freq_dim, t).float())

Callers 2

forward · Method · 0.45
forward · Method · 0.45

Calls 3

forward_vace · Method · 0.95
sinusoidal_embedding_1d · Function · 0.70
unpatchify · Method · 0.45

Tested by

no test coverage detected