MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / forward

Method forward

tensorrt_llm/layers/attention.py:1720–1846  ·  view source on GitHub ↗
(self,
                hidden_states: Tensor,
                attention_mask=None,
                input_lengths=None,
                max_input_length=None,
                lora_layer_params=None)

Source from the content-addressed store, hash-verified

1718 dtype=dtype)
1719
1720 def forward(self,
1721 hidden_states: Tensor,
1722 attention_mask=None,
1723 input_lengths=None,
1724 max_input_length=None,
1725 lora_layer_params=None):
1726 assert isinstance(hidden_states, Tensor)
1727
1728 qkv_lora_params = None
1729 if lora_layer_params is not None:
1730 qkv_lora_params = lora_layer_params.get_runtime_params(
1731 0, "attn_qkv")
1732
1733 qkv = self.qkv(hidden_states, qkv_lora_params)
1734
1735 if default_net().plugin_config.remove_input_padding:
1736 assert qkv.ndim() == 2
1737
1738 if default_net(
1739 ).plugin_config.lora_plugin and qkv_lora_params is None and lora_layer_params is not None:
1740 q_lora_params = lora_layer_params.get_runtime_params(0, "attn_q")
1741 k_lora_params = lora_layer_params.get_runtime_params(0, "attn_k")
1742 v_lora_params = lora_layer_params.get_runtime_params(0, "attn_v")
1743
1744 assert (q_lora_params is not None and k_lora_params is not None and v_lora_params is not None) or \
1745 (q_lora_params is None and k_lora_params is None and v_lora_params is None), "q_lora_params, k_lora_params and v_lora_params should be all enabled or all disabled at the same time."
1746
1747 if q_lora_params is not None and k_lora_params is not None and v_lora_params is not None:
1748 qkv_lora_params = LoraRuntimeParams(
1749 lora_ranks=[
1750 q_lora_params.lora_ranks[0],
1751 k_lora_params.lora_ranks[0],
1752 v_lora_params.lora_ranks[0],
1753 ],
1754 lora_weights_pointers=[
1755 q_lora_params.lora_weights_pointers[0],
1756 k_lora_params.lora_weights_pointers[0],
1757 v_lora_params.lora_weights_pointers[0],
1758 ],
1759 host_request_types=q_lora_params.host_request_types,
1760 host_context_lengths=q_lora_params.host_context_lengths)
1761
1762 q_lora, k_lora, v_lora = self.qkv_lora(hidden_states,
1763 qkv_lora_params)
1764 qkv_lora = concat([q_lora, k_lora, v_lora],
1765 dim=q_lora.rank() - 1)
1766 qkv = qkv + qkv_lora
1767
1768 if default_net().plugin_config.bert_attention_plugin:
1769 # TRT plugin mode
1770 assert input_lengths is not None
1771 context = bert_attention(
1772 qkv,
1773 input_lengths,
1774 self.num_attention_heads,
1775 self.attention_head_size,
1776 q_scaling=self.q_scaling,
1777 relative_attention=self.relative_attention,

Callers

nothing calls this directly

Calls 15

default_net · Function · 0.85
LoraRuntimeParams · Class · 0.85
concat · Function · 0.85
bert_attention · Function · 0.85
matmul · Function · 0.85
compute_relative_bias · Function · 0.85
expand_mask · Function · 0.85
cast · Function · 0.85
softmax · Function · 0.85
get_runtime_params · Method · 0.80
transpose · Method · 0.80
split · Function · 0.50

Tested by

no test coverage detected