hub / github.com/NVIDIA/TensorRT-LLM / bert_attention

Function bert_attention

tensorrt_llm/functional.py:4487–4676

Add an operation that performs the multi-head attention in BERT. The multi-head attention (MHA) is the sequence of a batched matmul, a softmax and a batched matmul, as described in https://arxiv.org/abs/1706.03762. This function adds an operation that performs those computations using a single GPU kernel.
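To make the matmul, softmax, matmul sequence concrete, here is a minimal pure-Python sketch of the unfused computation for a single head of a single sequence. It is only a reference for the math, not the plugin's implementation; the helper names (`softmax`, `matmul`, `transpose`, `reference_attention`) and the explicit `scale` argument are assumptions made for illustration.

```python
import math


def softmax(row):
    # Numerically stable softmax over one row of scores.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    # a: [n, k], b: [k, m] -> [n, m], all as nested lists.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def reference_attention(q, k, v, scale):
    """Hypothetical reference for one head.

    q, k, v are [seq_len, head_size] nested lists; scale multiplies Q*K^T.
    """
    scores = matmul(q, transpose(k))                       # [seq_len, seq_len]
    probs = [softmax([s * scale for s in row]) for row in scores]
    return matmul(probs, v)                                # [seq_len, head_size]


# With all-zero queries every score is 0, so the softmax is uniform and
# each output row is the mean of the rows of V.
out = reference_attention([[0.0, 0.0], [0.0, 0.0]],
                          [[1.0, 0.0], [0.0, 1.0]],
                          [[1.0, 2.0], [3.0, 4.0]],
                          scale=1.0)
print(out)  # [[2.0, 3.0], [2.0, 3.0]]
```

The batched form applies the same three steps independently to every head of every sequence; the fused kernel produces that result without materializing the intermediate score matrix as a separate operation.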

bert_attention(tensor: Tensor,
               input_lengths: Tensor,
               num_heads: int,
               head_size: int,
               q_scaling: float,
               relative_attention: bool = False,
               relative_attention_bias: Tensor = None,
               max_distance: int = 0,
               max_input_length: Tensor = None,
               sage_attn: bool = False,
               sage_attn_q_block_size: int = 0,
               sage_attn_k_block_size: int = 0,
               sage_attn_v_block_size: int = 0,
               cp_group: list[int] = None,
               cp_size: int = 1,
               cp_rank: int = 0)

Source from the content-addressed store, hash-verified

def bert_attention(tensor: Tensor,
                   input_lengths: Tensor,
                   num_heads: int,
                   head_size: int,
                   q_scaling: float,
                   relative_attention: bool = False,
                   relative_attention_bias: Tensor = None,
                   max_distance: int = 0,
                   max_input_length: Tensor = None,
                   sage_attn: bool = False,
                   sage_attn_q_block_size: int = 0,
                   sage_attn_k_block_size: int = 0,
                   sage_attn_v_block_size: int = 0,
                   cp_group: list[int] = None,
                   cp_size: int = 1,
                   cp_rank: int = 0) -> Tuple[Tensor]:
    '''
    Add an operation that performs the multi-head attention in BERT.

    The multi-head attention (MHA) is the sequence of a batched matmul, a
    softmax and a batched matmul as described in
    https://arxiv.org/abs/1706.03762. That function adds an operation that
    performs those computations using a single GPU kernel.

    The input tensor contains the Q, K and V elements. It is a 2D tensor and
    its shape is '[sum_of_tokens, 3*hidden_dim]' where the 'sum_of_tokens' is
    the sum of the sequence lengths in the batch.

    In MHA, the output of the Q*K^T product is scaled by a constant value that
    is computed as:

        1.f / (q_scaling * sqrt(head_size)).

    That 'q_scaling' constant is the last argument of that function.

    That layer is implemented using a plugin (see bertAttentionPlugin).

    Parameters:
        tensor : Tensor
            The QKV input tensor.

        input_lengths : Tensor
            The length of each sequence. It is a 1D tensor of size 'batch_size'.

        num_heads : int
            The number of heads.

        head_size : int
            The size of each head.

        q_scaling : float
            The factor to compute the scaling factor to scale the output of the
            'Q*K^T' product.

        relative_attention: bool = False
            If enable relative attention.

        relative_attention_bias: Tensor = None
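
The docstring above describes a packed input of shape '[sum_of_tokens, 3*hidden_dim]' and a scaling constant of 1.f / (q_scaling * sqrt(head_size)). The sketch below illustrates both in plain Python, under stated assumptions: that hidden_dim equals num_heads * head_size, and that each row stores Q, K and V as three consecutive hidden_dim-wide slices. Neither assumption is confirmed by the excerpt, and the plugin may order the elements differently. All helper names are hypothetical.

```python
import math
from itertools import accumulate


def softmax_scale(q_scaling, head_size):
    # Constant applied to Q*K^T, as given in the docstring.
    return 1.0 / (q_scaling * math.sqrt(head_size))


def sequence_row_ranges(input_lengths):
    """Row range [start, end) of each sequence along the packed token axis."""
    ends = list(accumulate(input_lengths))
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))


def split_qkv_row(row, num_heads, head_size):
    """Split one packed row into its Q, K and V slices.

    ASSUMPTION: hidden_dim == num_heads * head_size and the row is laid
    out as Q | K | V. The real plugin layout may differ.
    """
    hidden_dim = num_heads * head_size
    if len(row) != 3 * hidden_dim:
        raise ValueError("row width must be 3 * num_heads * head_size")
    q, k, v = (row[i * hidden_dim:(i + 1) * hidden_dim] for i in range(3))
    return q, k, v


# Two sequences of lengths 3 and 2 give sum_of_tokens = 5.
print(sequence_row_ranges([3, 2]))   # [(0, 3), (3, 5)]
print(softmax_scale(1.0, 64))        # 0.125
print(split_qkv_row([1, 2, 3, 4, 5, 6], num_heads=1, head_size=2))
# ([1, 2], [3, 4], [5, 6])
```

Because sequences are concatenated without padding, input_lengths is what lets a consumer recover the per-sequence boundaries; attention is then computed within each row range only.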

Callers 3

forward · Method · 0.85
joint_attn_forward · Method · 0.85
forward · Method · 0.85

Calls 12

default_net · Function · 0.85
str_dtype_to_trt · Function · 0.85
chunk · Function · 0.85
concat · Function · 0.85
default_trtnet · Function · 0.85
_add_plugin_info · Function · 0.85
_create_tensor · Function · 0.85
int8 · Method · 0.80
create_plugin · Method · 0.80
shape · Function · 0.70
view · Method · 0.45
get_output · Method · 0.45

Tested by

no test coverage detected