Repository: github.com/Robbyant/lingbot-world

Method forward

Defined in wan/modules/s2v/model_s2v.py, lines 649–856.

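Forward pass of the S2V model. It takes per-video latents x ([C, T, H, W] each), timesteps t ([B]), text embeddings context ([L, C] each), reference-image latents, motion latents, condition (pose) frames and an audio embedding. seq_len is accepted but unused. The full parameter descriptions are in the docstring below.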
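The shape rules implied by the docstring can be collected into a pre-flight check. The helper below, check_s2v_input_shapes, is hypothetical and not part of the repository. It reads shapes as plain tuples. Two of its rules are assumptions. It requires cond_states to match x in T, H and W, on the assumption that cond_encoder patches its input the same way patch_embedding does, because the method adds the two outputs elementwise. It also requires ref_latents and motion_latents to share H and W with x, because the docstring uses the same symbols for them.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, ...]


def check_s2v_input_shapes(
    x: Sequence[Shape],
    cond_states: Sequence[Shape],
    ref_latents: Sequence[Shape],
    motion_latents: Sequence[Shape],
    audio_input: Shape,
) -> List[str]:
    """Hypothetical helper: return shape problems (empty list if none found).

    The rules are read off the forward() docstring, not taken from the repo.
    """
    problems: List[str] = []
    batch = len(x)
    per_video = {"cond_states": cond_states, "ref_latents": ref_latents,
                 "motion_latents": motion_latents}
    for name, shapes in per_video.items():
        if len(shapes) != batch:
            problems.append(f"{name}: {len(shapes)} entries for {batch} videos")
    # audio_input is a single batched tensor: [B, num_wav2vec_layer, C_a, T_a].
    if len(audio_input) != 4 or audio_input[0] != batch:
        problems.append(f"audio_input: expected [{batch}, L, C_a, T_a], "
                        f"got {list(audio_input)}")

    for i, shape in enumerate(x):
        if len(shape) != 4:
            problems.append(f"x[{i}]: expected [C, T, H, W], got {list(shape)}")
            continue
        _, t, h, w = shape
        # Assumption: cond frames are encoded like x, so T, H, W must agree.
        if i < len(cond_states) and tuple(cond_states[i][1:]) != (t, h, w):
            problems.append(f"cond_states[{i}]: T, H, W differ from x[{i}]")
        # Reference latents hold a single frame at the same spatial size.
        if i < len(ref_latents) and tuple(ref_latents[i][1:]) != (1, h, w):
            problems.append(f"ref_latents[{i}]: expected [C, 1, {h}, {w}]")
        # Motion latents may have any length T_m but share H and W.
        if i < len(motion_latents) and tuple(motion_latents[i][2:]) != (h, w):
            problems.append(f"motion_latents[{i}]: H, W differ from x[{i}]")
    return problems


# Illustrative sizes only, not taken from the repository.
print(check_s2v_input_shapes(
    x=[(16, 4, 8, 8)],
    cond_states=[(16, 4, 8, 8)],
    ref_latents=[(16, 1, 8, 8)],
    motion_latents=[(16, 2, 8, 8)],
    audio_input=(1, 3, 7, 10),
))  # []
```

The check runs on shapes alone, so it can be called with `[tuple(u.shape) for u in x]` before any tensor reaches the model.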

Source excerpt (file lines 649–706; the rest of the method, through line 856, is not shown):

```python
def forward(
        self,
        x,
        t,
        context,
        seq_len,
        ref_latents,
        motion_latents,
        cond_states,
        audio_input=None,
        motion_frames=[17, 5],
        add_last_motion=2,
        drop_motion_frames=False,
        *extra_args,
        **extra_kwargs):
    """
    x: A list of videos, each with shape [C, T, H, W].
    t: Timesteps with shape [B].
    context: A list of text embeddings, each with shape [L, C].
    seq_len: A list of video token lengths; not used by this model.
    ref_latents: A list of reference-image latents, one per video, each with shape [C, 1, H, W].
    motion_latents: A list of motion-frame latents, one per video, each with shape [C, T_m, H, W].
    cond_states: A list of condition frames (i.e., pose), each with shape [C, T, H, W].
    audio_input: The input audio embedding with shape [B, num_wav2vec_layer, C_a, T_a].
    motion_frames: The number of motion frames and the number of motion latent frames
        after VAE encoding, e.g. [17, 5].
    add_last_motion: For the motioner, a value > 0 means the most recent (last) frame is added.
        For frame packing, the behavior depends on the value:
            0: only the farthest part of the latent (clean_latents_4x) is included.
            1: both clean_latents_2x and clean_latents_4x are included.
            2: all motion-related latents are used.
    drop_motion_frames: bool, whether to drop the motion-frame information.
    """
```
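```python
    # Effective value: the model-level setting scales the call argument.
    add_last_motion = self.add_last_motion * add_last_motion
    # Prepend motion_frames[0] copies of the first audio step along the last
    # (time) axis. audio_input is indexed unconditionally here, so it must be
    # provided despite its None default.
    audio_input = torch.cat(
        [audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input],
        dim=-1)
    # casual_audio_encoder and enbale_adain are spelled as defined on the class.
    audio_emb_res = self.casual_audio_encoder(audio_input)
    if self.enbale_adain:
        audio_emb_global, audio_emb = audio_emb_res
        self.audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
    else:
        audio_emb = audio_emb_res
    # Drop the first motion_frames[1] steps of the encoder output along dim 1.
    self.merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
```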
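The audio preprocessing pads the time axis before encoding and trims the encoder output afterwards. The NumPy sketch below mirrors only the tensor bookkeeping. pad_audio_front stands in for the torch.cat / repeat call; because the repeated slice has size 1 on the last axis, np.repeat gives the same result as torch's tiling repeat. The encoder itself (casual_audio_encoder) is not reproduced. fake_emb is an arbitrary stand-in array, and the function names are hypothetical.

```python
import numpy as np


def pad_audio_front(audio: np.ndarray, n_pad: int) -> np.ndarray:
    """Prepend `n_pad` copies of the first time step along the last axis.

    audio: [B, num_layers, C_a, T_a] -> [B, num_layers, C_a, n_pad + T_a].
    """
    first = np.repeat(audio[..., 0:1], n_pad, axis=-1)
    return np.concatenate([first, audio], axis=-1)


def drop_leading_steps(emb: np.ndarray, n_steps: int) -> np.ndarray:
    """Drop the first `n_steps` entries along axis 1 of an array [B, T, ...]."""
    return emb[:, n_steps:]


motion_frames = [17, 5]
audio = np.random.default_rng(0).standard_normal((1, 3, 8, 40))  # illustrative
padded = pad_audio_front(audio, motion_frames[0])
print(padded.shape)  # (1, 3, 8, 57)

# Stand-in for an encoder output of shape [B, T_out, C_out]; sizes are arbitrary.
fake_emb = np.zeros((1, 15, 32))
print(drop_leading_steps(fake_emb, motion_frames[1]).shape)  # (1, 10, 32)
```

Note the asymmetry. The padding uses motion_frames[0] (pixel-rate count), but the trim uses motion_frames[1] (latent-rate count). This suggests that the real encoder emits one step per latent frame, although the excerpt does not show that directly.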
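```python
    device = self.patch_embedding.weight.device

    # embeddings: patch-embed each video, adding a batch dim of 1
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
    # cond states: encode the condition (pose) frames and add them to x
    cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states]
    x = [x_ + pose for x_, pose in zip(x, cond)]

    # Record each video's (F, H, W) grid after patching, then flatten the
    # feature map [1, dim, F, H, W] into a token sequence [1, F*H*W, dim].
    grid_sizes = torch.stack(
        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
    x = [u.flatten(2).transpose(1, 2) for u in x]
    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
    # ... (file lines 707–856 are not shown in this excerpt)
```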
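The embedding and tokenization steps can be sketched in NumPy. The sketch assumes, as in the upstream Wan models, that patch_embedding is a Conv3d whose kernel_size and stride both equal the patch size (1, 2, 2), with no padding. It further assumes that cond_encoder patches its input the same way. Both weights below are random stand-ins, the bias is omitted, and patch_embed and to_tokens are hypothetical names. The point of the sketch is the shape bookkeeping that produces grid_sizes and seq_lens.

```python
import numpy as np


def patch_embed(video: np.ndarray, weight: np.ndarray,
                patch=(1, 2, 2)) -> np.ndarray:
    """Non-overlapping 3D patch embedding (no bias).

    Equivalent to Conv3d with kernel_size == stride == patch and no padding.
    video: [C, T, H, W]; weight: [D, C, pt, ph, pw].
    Returns [1, D, T // pt, H // ph, W // pw] (batch dim 1, like unsqueeze(0)).
    """
    c, t, h, w = video.shape
    pt, ph, pw = patch
    f, gh, gw = t // pt, h // ph, w // pw
    # Crop any remainder, which an unpadded strided convolution also ignores.
    v = video[:, :f * pt, :gh * ph, :gw * pw]
    v = v.reshape(c, f, pt, gh, ph, gw, pw)
    # Sum over input channels and positions inside each patch.
    out = np.einsum("cfigjkl,dcijl->dfgk", v, weight)
    return out[None]


def to_tokens(feats):
    """Mirror grid_sizes, flatten(2).transpose(1, 2) and seq_lens in forward."""
    grid_sizes = np.stack([np.array(u.shape[2:], dtype=np.int64) for u in feats])
    tokens = [u.reshape(u.shape[0], u.shape[1], -1).transpose(0, 2, 1)
              for u in feats]
    seq_lens = np.array([u.shape[1] for u in tokens], dtype=np.int64)
    return tokens, grid_sizes, seq_lens


rng = np.random.default_rng(0)
dim, c = 8, 4                                    # illustrative sizes
w_x = rng.standard_normal((dim, c, 1, 2, 2))     # stand-in for patch_embedding
w_cond = rng.standard_normal((dim, c, 1, 2, 2))  # stand-in for cond_encoder
videos = [rng.standard_normal((c, 3, 4, 6)), rng.standard_normal((c, 2, 8, 8))]
poses = [rng.standard_normal(v.shape) for v in videos]

# Patch-embed each video and add its encoded condition frames.
feats = [patch_embed(v, w_x) + patch_embed(p, w_cond)
         for v, p in zip(videos, poses)]
tokens, grid_sizes, seq_lens = to_tokens(feats)
print(grid_sizes.tolist())  # [[3, 2, 3], [2, 4, 4]]
print(seq_lens.tolist())    # [18, 32]
```

Because each video keeps its own grid, the lists can hold videos of different sizes. seq_lens then records each token count, which is F × H × W of the patched grid.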

Callers

No direct callers detected (as a module's forward, it is normally invoked through the module's __call__).

Calls 10

inject_motion (method) · 0.95
unpatchify (method) · 0.95
rope_precompute (function) · 0.85
get_rank (function) · 0.85
get_world_size (function) · 0.85
gather_forward (function) · 0.85
size (method) · 0.80
to (method) · 0.80
sinusoidal_embedding_1d (function) · 0.70

Tested by

No test coverage detected.