| 181 | class EncoderLayer(Module): |
| 182 | |
| 183 | def __init__(self, |
| 184 | hidden_size, |
| 185 | ffn_hidden_size, |
| 186 | num_attention_heads, |
| 187 | num_kv_heads, |
| 188 | head_size, |
| 189 | max_position_embeddings=None, |
| 190 | q_scaling=1.0, |
| 191 | has_attention_qkvo_bias=False, |
| 192 | has_mlp_bias=False, |
| 193 | layernorm_position=LayerNormPositionType.pre_layernorm, |
| 194 | layernorm_type=LayerNormType.LayerNorm, |
| 195 | layernorm_eps=1e-5, |
| 196 | hidden_act="relu", |
| 197 | mlp_type=MLPType.MLP, |
| 198 | mapping=Mapping(), |
| 199 | dtype=None, |
| 200 | residual_scaling=1.0, |
| 201 | relative_attention=False, |
| 202 | max_distance=0, |
| 203 | num_buckets=0, |
| 204 | fp16_clamping=False, |
| 205 | quant_mode=QuantMode(0), |
| 206 | language_adapter_config: LanguageAdapterConfig = None): |
| 207 | super().__init__() |
| 208 | |
| 209 | # e.g. BART uses regular LayerNorm, T5 uses RMSNorm |
| 210 | self.layernorm_type = layernorm_type |
| 211 | ln_type = layernorm_map[layernorm_type] |
| 212 | |
| 213 | # e.g. BART uses post-layernorm, T5 uses pre-layernorm |
| 214 | self.layernorm_position = layernorm_position |
| 215 | |
| 216 | # e.g. BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size) |
| 217 | self.attention = BertAttention( |
| 218 | hidden_size, |
| 219 | num_attention_heads, |
| 220 | attention_head_size=head_size, |
| 221 | num_kv_heads=num_kv_heads, |
| 222 | max_position_embeddings=max_position_embeddings, |
| 223 | q_scaling=q_scaling, |
| 224 | bias=has_attention_qkvo_bias, |
| 225 | tp_group=mapping.tp_group, |
| 226 | tp_size=mapping.tp_size, |
| 227 | tp_rank=mapping.tp_rank, |
| 228 | dtype=dtype, |
| 229 | relative_attention=relative_attention, |
| 230 | max_distance=max_distance, |
| 231 | num_buckets=num_buckets, |
| 232 | quant_mode=quant_mode) |
| 233 | |
| 234 | self.attention_layernorm = ln_type(normalized_shape=hidden_size, |
| 235 | eps=layernorm_eps, |
| 236 | dtype=dtype) |
| 237 | |
| 238 | # T5/BART MLP, Flan-T5 GatedMLP |
| 239 | self.mlp_type = mlp_type |
| 240 | mlp_f = mlp_map[mlp_type] |