MCP · Copy · Index your code
hub / github.com/NVIDIA/TensorRT-LLM / DecoderLayer

Class DecoderLayer

tensorrt_llm/models/enc_dec/model.py:345–596  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

343
344
345class DecoderLayer(Module):
346
347 def __init__(self,
348 *,
349 local_layer_idx,
350 hidden_size,
351 ffn_hidden_size,
352 num_attention_heads,
353 num_kv_heads,
354 head_size,
355 max_position_embeddings=None,
356 q_scaling=1.0,
357 has_attention_qkvo_bias=False,
358 has_mlp_bias=False,
359 layernorm_position=LayerNormPositionType.pre_layernorm,
360 layernorm_type=LayerNormType.LayerNorm,
361 layernorm_eps=1e-5,
362 hidden_act="relu",
363 mlp_type=MLPType.MLP,
364 mapping=Mapping(),
365 dtype=None,
366 residual_scaling=1.0,
367 relative_attention=False,
368 max_distance=0,
369 num_buckets=0,
370 fp16_clamping=False,
371 skip_cross_kv=False,
372 use_implicit_relative_attention=False,
373 quant_mode=QuantMode(0),
374 language_adapter_config: LanguageAdapterConfig = None):
375 super().__init__()
376
377 # e.g. BART regular, T5 RMS
378 self.layernorm_type = layernorm_type
379 ln_type = layernorm_map[layernorm_type]
380
381 # e.g. BART post, T5 pre
382 self.layernorm_position = layernorm_position
383
384 # e.g. BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size)
385 self.self_attention = Attention(
386 local_layer_idx=local_layer_idx,
387 hidden_size=hidden_size,
388 num_attention_heads=num_attention_heads,
389 attention_head_size=head_size,
390 num_kv_heads=num_kv_heads,
391 max_position_embeddings=max_position_embeddings,
392 q_scaling=q_scaling,
393 bias=has_attention_qkvo_bias,
394 attention_mask_type=AttentionMaskType.causal,
395 tp_group=mapping.tp_group,
396 tp_size=mapping.tp_size,
397 tp_rank=mapping.tp_rank,
398 dtype=dtype,
399 cross_attention=False,
400 relative_attention=relative_attention,
401 max_distance=max_distance if use_implicit_relative_attention else 0,
402 num_buckets=num_buckets,

Callers 1

`__init__` (method) · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected