| 42 | |
| 43 | |
class VanillaTemporalModule(nn.Module):
    """Thin wrapper that runs a ``TemporalTransformer3DModel`` over its input.

    Args:
        in_channels: Channel count handed to the transformer; also used to
            derive the per-head attention width.
        num_attention_heads: Number of attention heads for the transformer.
        num_transformer_block: Forwarded as the transformer's ``num_layers``.
        attention_block_types: Block type names, forwarded unchanged.
        cross_frame_attention_mode: Forwarded unchanged to the transformer.
        temporal_position_encoding: Forwarded unchanged to the transformer.
        temporal_position_encoding_max_len: Forwarded unchanged to the
            transformer.
        temporal_attention_dim_div: Extra divisor applied to the per-head
            attention width.
        zero_initialize: When truthy, the transformer's ``proj_out`` is
            replaced by the result of ``zero_module(proj_out)`` right after
            construction (presumably zeroing its weights so the module starts
            out contributing nothing -- confirm against ``zero_module``).
    """

    def __init__(
        self,
        in_channels,
        num_attention_heads=8,
        num_transformer_block=2,
        attention_block_types=("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        temporal_attention_dim_div=1,
        zero_initialize=True,
    ):
        super().__init__()

        # Channels are split evenly across heads and then optionally narrowed
        # further; floor division means any remainder is simply dropped.
        head_dim = in_channels // num_attention_heads // temporal_attention_dim_div

        transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=head_dim,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
        )
        # Register the submodule first, then patch its output projection via
        # the local alias (same object as self.temporal_transformer).
        self.temporal_transformer = transformer

        if not zero_initialize:
            return
        transformer.proj_out = zero_module(transformer.proj_out)

    def forward(
        self,
        input_tensor,
        temb,
        encoder_hidden_states,
        attention_mask=None,
        anchor_frame_idx=None,
    ):
        """Run the wrapped temporal transformer on ``input_tensor``.

        ``temb`` and ``anchor_frame_idx`` are accepted but not used by this
        implementation (presumably so callers can pass a uniform argument set
        to every motion-module variant -- verify against callers).

        Returns:
            Whatever ``self.temporal_transformer`` returns for
            ``(input_tensor, encoder_hidden_states, attention_mask)``.
        """
        return self.temporal_transformer(
            input_tensor, encoder_hidden_states, attention_mask
        )
| 92 | |
| 93 | |
| 94 | class TemporalTransformer3DModel(nn.Module): |