MCPcopy
hub / github.com/fudan-generative-vision/champ / VanillaTemporalModule

Class VanillaTemporalModule

models/motion_module.py:44–91  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

42
43
class VanillaTemporalModule(nn.Module):
    """Thin wrapper that builds and runs a ``TemporalTransformer3DModel``.

    The wrapped transformer is exposed as ``self.temporal_transformer``.
    Rather than taking a per-head attention width directly, this wrapper
    derives it from ``in_channels``, ``num_attention_heads`` and
    ``temporal_attention_dim_div``.

    Args:
        in_channels: Channel count forwarded to the transformer; also the
            starting point for the per-head attention dimension.
        num_attention_heads: Number of attention heads (also the first
            divisor applied to ``in_channels``).
        num_transformer_block: Forwarded to the transformer as
            ``num_layers``.
        attention_block_types: Forwarded unchanged.
        cross_frame_attention_mode: Forwarded unchanged.
        temporal_position_encoding: Forwarded unchanged.
        temporal_position_encoding_max_len: Forwarded unchanged.
        temporal_attention_dim_div: Second divisor applied to the per-head
            dimension after dividing by ``num_attention_heads``.
        zero_initialize: If true, the transformer's ``proj_out`` is replaced
            with ``zero_module(proj_out)`` after construction.
    """

    def __init__(
        self,
        in_channels,
        num_attention_heads=8,
        num_transformer_block=2,
        attention_block_types=("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        temporal_attention_dim_div=1,
        zero_initialize=True,
    ):
        super().__init__()

        # Same evaluation order as a chained `a // b // c`:
        # first divide by the head count, then by the extra divisor.
        # NOTE(review): floor division silently drops any remainder; confirm
        # in_channels is expected to be divisible by
        # num_attention_heads * temporal_attention_dim_div.
        per_head_dim = in_channels // num_attention_heads
        per_head_dim = per_head_dim // temporal_attention_dim_div

        transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=per_head_dim,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
        )

        if zero_initialize:
            # zero_module is defined elsewhere; presumably it zeroes the
            # module's parameters and returns the module, so the output
            # projection starts out contributing nothing — TODO confirm.
            transformer.proj_out = zero_module(transformer.proj_out)

        self.temporal_transformer = transformer

    def forward(
        self,
        input_tensor,
        temb,
        encoder_hidden_states,
        attention_mask=None,
        anchor_frame_idx=None,
    ):
        """Run the wrapped temporal transformer on ``input_tensor``.

        Only ``input_tensor``, ``encoder_hidden_states`` and
        ``attention_mask`` are passed on, positionally and in that order.
        ``temb`` and ``anchor_frame_idx`` are accepted but not used here;
        presumably they keep the call signature uniform for callers that
        pass them to every motion module — confirm against the callers.

        Returns:
            Whatever ``self.temporal_transformer`` returns, unchanged.
        """
        return self.temporal_transformer(
            input_tensor, encoder_hidden_states, attention_mask
        )
92
93
94class TemporalTransformer3DModel(nn.Module):

Callers 1

get_motion_module (function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected