(self,
hidden_states: Tensor,
attention_mask=None,
use_cache=False,
kv_cache_params=None,
attention_params=None)
| 102 | return self.new_decoder_architecture or self.parallel_attn |
| 103 | |
def forward(self,
            hidden_states: Tensor,
            attention_mask=None,
            use_cache=False,
            kv_cache_params=None,
            attention_params=None):
    """Apply one Falcon decoder layer to ``hidden_states``.

    Attention always consumes ``input_layernorm(hidden_states)``. The MLP
    input is chosen from the configuration flags:

    * new decoder architecture with ``num_ln_in_parallel_attn == 2``:
      ``mlp_layernorm`` applied to the layer input;
    * new decoder architecture otherwise, or old architecture with
      ``parallel_attn``: the same ``input_layernorm`` output attention saw;
    * old sequential architecture: ``post_layernorm(residual + attn)``,
      and that sum becomes the residual for the final add.

    When ``is_parallel_attention`` is truthy, the MLP and attention outputs
    are summed (and all-reduced over ``config.mapping.tp_group`` when
    ``tp_size > 1``) before the residual add.

    Args:
        hidden_states: layer input; must be a ``Tensor``.
        attention_mask: forwarded unchanged to ``self.attention``.
        use_cache: forwarded to ``self.attention``; when truthy, the
            attention result is unpacked as ``(output, presents)``.
        kv_cache_params: forwarded unchanged to ``self.attention``.
        attention_params: forwarded unchanged to ``self.attention``.

    Returns:
        The layer output, or ``(output, presents)`` when ``use_cache`` is
        truthy. ``presents`` is whatever ``self.attention`` returned as the
        second element (presumably the updated KV cache; confirm against
        the attention module).
    """
    assert isinstance(hidden_states, Tensor)

    skip = hidden_states
    new_arch = self.new_decoder_architecture

    # Keep the original call order: mlp_layernorm (if any) before
    # input_layernorm, since each call adds nodes to the built graph.
    if new_arch and self.num_ln_in_parallel_attn == 2:
        mlp_normed = self.mlp_layernorm(hidden_states)
    attn_normed = self.input_layernorm(hidden_states)

    attn_out = self.attention(attn_normed,
                              attention_mask=attention_mask,
                              use_cache=use_cache,
                              kv_cache_params=kv_cache_params,
                              attention_params=attention_params)
    if use_cache:
        attn_out, presents = attn_out

    # Select what the MLP sees.
    if not new_arch:
        if self.parallel_attn:
            mlp_in = attn_normed
        else:
            # Sequential block: residual-add attention first, then re-norm.
            skip = skip + attn_out
            mlp_in = self.post_layernorm(skip)
    elif self.num_ln_in_parallel_attn == 2:
        mlp_in = mlp_normed
    else:
        mlp_in = attn_normed

    mlp_out = self.mlp(mlp_in)

    if self.is_parallel_attention:
        mlp_out = mlp_out + attn_out
        mapping = self.config.mapping
        if mapping.tp_size > 1:
            # One reduction covers both parallel branches (presumably the
            # sublayers skip their own reduce in this mode; verify __init__).
            mlp_out = allreduce(mlp_out, mapping.tp_group)

    output = skip + mlp_out
    if use_cache:
        return output, presents
    return output
| 153 | |
| 154 | |
| 155 | class FalconModel(Module): |
nothing calls this directly
no test coverage detected