(
self,
x,
t,
clip_fea,
context,
seq_len,
y=None,
pose_latents=None,
face_pixel_values=None
)
| 370 | |
| 371 | |
| 372 | def forward( |
| 373 | self, |
| 374 | x, |
| 375 | t, |
| 376 | clip_fea, |
| 377 | context, |
| 378 | seq_len, |
| 379 | y=None, |
| 380 | pose_latents=None, |
| 381 | face_pixel_values=None |
| 382 | ): |
| 383 | # params |
| 384 | device = self.patch_embedding.weight.device |
| 385 | if self.freqs.device != device: |
| 386 | self.freqs = self.freqs.to(device) |
| 387 | |
| 388 | if y is not None: |
| 389 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] |
| 390 | |
| 391 | # embeddings |
| 392 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] |
| 393 | x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values) |
| 394 | |
| 395 | grid_sizes = torch.stack( |
| 396 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) |
| 397 | x = [u.flatten(2).transpose(1, 2) for u in x] |
| 398 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) |
| 399 | assert seq_lens.max() <= seq_len |
| 400 | x = torch.cat([ |
| 401 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], |
| 402 | dim=1) for u in x |
| 403 | ]) |
| 404 | |
| 405 | # time embeddings |
| 406 | with amp.autocast(dtype=torch.float32): |
| 407 | e = self.time_embedding( |
| 408 | sinusoidal_embedding_1d(self.freq_dim, t).float() |
| 409 | ) |
| 410 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) |
| 411 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 |
| 412 | |
| 413 | # context |
| 414 | context_lens = None |
| 415 | context = self.text_embedding( |
| 416 | torch.stack([ |
| 417 | torch.cat( |
| 418 | [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) |
| 419 | for u in context |
| 420 | ])) |
| 421 | |
| 422 | if self.use_img_emb: |
| 423 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim |
| 424 | context = torch.concat([context_clip, context], dim=1) |
| 425 | |
| 426 | # arguments |
| 427 | kwargs = dict( |
| 428 | e=e0, |
| 429 | seq_lens=seq_lens, |
nothing calls this directly
no test coverage detected