x: A list of videos each with shape [C, T, H, W]. t: [B]. context: A list of text embeddings each with shape [L, C]. seq_len: A list of video token lengths; not used by this model. ref_latents: A list of reference images for each video with shape [C, 1, H, W].
(
self,
x,
t,
context,
seq_len,
ref_latents,
motion_latents,
cond_states,
audio_input=None,
motion_frames=[17, 5],
add_last_motion=2,
drop_motion_frames=False,
*extra_args,
**extra_kwargs)
| 647 | return hidden_states |
| 648 | |
| 649 | def forward( |
| 650 | self, |
| 651 | x, |
| 652 | t, |
| 653 | context, |
| 654 | seq_len, |
| 655 | ref_latents, |
| 656 | motion_latents, |
| 657 | cond_states, |
| 658 | audio_input=None, |
| 659 | motion_frames=[17, 5], |
| 660 | add_last_motion=2, |
| 661 | drop_motion_frames=False, |
| 662 | *extra_args, |
| 663 | **extra_kwargs): |
| 664 | """ |
| 665 | x: A list of videos each with shape [C, T, H, W]. |
| 666 | t: [B]. |
| 667 | context: A list of text embeddings each with shape [L, C]. |
| 668 | seq_len: A list of video token lengths; not used by this model. |
| 669 | ref_latents: A list of reference images for each video with shape [C, 1, H, W]. |
| 670 | motion_latents: A list of motion frames for each video with shape [C, T_m, H, W]. |
| 671 | cond_states: A list of condition frames (i.e. pose) each with shape [C, T, H, W]. |
| 672 | audio_input: The input audio embedding [B, num_wav2vec_layer, C_a, T_a]. |
| 673 | motion_frames: The number of motion frames and of motion latent frames encoded by the VAE, e.g. [17, 5]. |
| 674 | add_last_motion: For the motioner, if add_last_motion > 0, the most recent frame (i.e., the last frame) is added. |
| 675 | For frame packing, the behavior depends on the value of add_last_motion: |
| 676 | add_last_motion = 0: Only the farthest part of the latent (i.e., clean_latents_4x) is included. |
| 677 | add_last_motion = 1: Both clean_latents_2x and clean_latents_4x are included. |
| 678 | add_last_motion = 2: All motion-related latents are used. |
| 679 | drop_motion_frames: Bool, whether to drop the motion-frame information. |
| 680 | """ |
| 681 | add_last_motion = self.add_last_motion * add_last_motion |
| 682 | audio_input = torch.cat([ |
| 683 | audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input |
| 684 | ], |
| 685 | dim=-1) |
| 686 | audio_emb_res = self.casual_audio_encoder(audio_input) |
| 687 | if self.enbale_adain: |
| 688 | audio_emb_global, audio_emb = audio_emb_res |
| 689 | self.audio_emb_global = audio_emb_global[:, |
| 690 | motion_frames[1]:].clone() |
| 691 | else: |
| 692 | audio_emb = audio_emb_res |
| 693 | self.merged_audio_emb = audio_emb[:, motion_frames[1]:, :] |
| 694 | |
| 695 | device = self.patch_embedding.weight.device |
| 696 | |
| 697 | # embeddings |
| 698 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] |
| 699 | # cond states |
| 700 | cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states] |
| 701 | x = [x_ + pose for x_, pose in zip(x, cond)] |
| 702 | |
| 703 | grid_sizes = torch.stack( |
| 704 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) |
| 705 | x = [u.flatten(2).transpose(1, 2) for u in x] |
| 706 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) |
nothing calls this directly
no test coverage detected