| 343 | |
| 344 | |
| 345 | class DecoderLayer(Module): |
| 346 | |
| 347 | def __init__(self, |
| 348 | *, |
| 349 | local_layer_idx, |
| 350 | hidden_size, |
| 351 | ffn_hidden_size, |
| 352 | num_attention_heads, |
| 353 | num_kv_heads, |
| 354 | head_size, |
| 355 | max_position_embeddings=None, |
| 356 | q_scaling=1.0, |
| 357 | has_attention_qkvo_bias=False, |
| 358 | has_mlp_bias=False, |
| 359 | layernorm_position=LayerNormPositionType.pre_layernorm, |
| 360 | layernorm_type=LayerNormType.LayerNorm, |
| 361 | layernorm_eps=1e-5, |
| 362 | hidden_act="relu", |
| 363 | mlp_type=MLPType.MLP, |
| 364 | mapping=Mapping(), |
| 365 | dtype=None, |
| 366 | residual_scaling=1.0, |
| 367 | relative_attention=False, |
| 368 | max_distance=0, |
| 369 | num_buckets=0, |
| 370 | fp16_clamping=False, |
| 371 | skip_cross_kv=False, |
| 372 | use_implicit_relative_attention=False, |
| 373 | quant_mode=QuantMode(0), |
| 374 | language_adapter_config: LanguageAdapterConfig = None): |
| 375 | super().__init__() |
| 376 | |
| 377 | # e.g. BART regular, T5 RMS |
| 378 | self.layernorm_type = layernorm_type |
| 379 | ln_type = layernorm_map[layernorm_type] |
| 380 | |
| 381 | # e.g. BART post, T5 pre |
| 382 | self.layernorm_position = layernorm_position |
| 383 | |
| 384 | # e.g. BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size) |
| 385 | self.self_attention = Attention( |
| 386 | local_layer_idx=local_layer_idx, |
| 387 | hidden_size=hidden_size, |
| 388 | num_attention_heads=num_attention_heads, |
| 389 | attention_head_size=head_size, |
| 390 | num_kv_heads=num_kv_heads, |
| 391 | max_position_embeddings=max_position_embeddings, |
| 392 | q_scaling=q_scaling, |
| 393 | bias=has_attention_qkvo_bias, |
| 394 | attention_mask_type=AttentionMaskType.causal, |
| 395 | tp_group=mapping.tp_group, |
| 396 | tp_size=mapping.tp_size, |
| 397 | tp_rank=mapping.tp_rank, |
| 398 | dtype=dtype, |
| 399 | cross_attention=False, |
| 400 | relative_attention=relative_attention, |
| 401 | max_distance=max_distance if use_implicit_relative_attention else 0, |
| 402 | num_buckets=num_buckets, |