Add an operation that performs the multi-head attention in BERT. The multi-head attention (MHA) is the sequence of a batched matmul, a softmax and a batched matmul as described in https://arxiv.org/abs/1706.03762. This function adds an operation that performs those computations using a single GPU kernel.
bert_attention(tensor: Tensor,
               input_lengths: Tensor,
               num_heads: int,
               head_size: int,
               q_scaling: float,
               relative_attention: bool = False,
               relative_attention_bias: Tensor = None,
               max_distance: int = 0,
               max_input_length: Tensor = None,
               sage_attn: bool = False,
               sage_attn_q_block_size: int = 0,
               sage_attn_k_block_size: int = 0,
               sage_attn_v_block_size: int = 0,
               cp_group: list[int] = None,
               cp_size: int = 1,
               cp_rank: int = 0) -> Tuple[Tensor]
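As a rough illustration of the computation described here, the following is a minimal NumPy sketch of padding-free multi-head attention on a packed QKV input. It is not the bertAttentionPlugin implementation and does not model relative attention, SageAttention or context parallelism (the 'cp_*' arguments). The function name 'reference_bert_attention' is hypothetical, 'input_lengths' is taken as a plain Python list rather than a Tensor, and the memory layout is an assumption: the last dimension is split as [Q | K | V], each of width hidden_dim = num_heads * head_size, with each head stored contiguously inside each part. The scale follows the 1 / (q_scaling * sqrt(head_size)) formula from the docstring.

```python
import math

import numpy as np


def reference_bert_attention(qkv, input_lengths, num_heads, head_size, q_scaling):
    """Hypothetical NumPy reference for packed (padding-free) BERT-style MHA.

    Assumed layout: qkv has shape [sum_of_tokens, 3 * hidden_dim], split along
    the last axis as [Q | K | V], with hidden_dim = num_heads * head_size and
    head h occupying columns [h * head_size, (h + 1) * head_size) of each part.
    Returns an array of shape [sum_of_tokens, hidden_dim].
    """
    hidden_dim = num_heads * head_size
    assert qkv.shape == (sum(input_lengths), 3 * hidden_dim)
    scale = 1.0 / (q_scaling * math.sqrt(head_size))

    def split_heads(x):
        # [seq_len, hidden_dim] -> [num_heads, seq_len, head_size]
        return x.reshape(-1, num_heads, head_size).transpose(1, 0, 2)

    out = np.empty((qkv.shape[0], hidden_dim), dtype=qkv.dtype)
    start = 0
    for length in input_lengths:
        end = start + length
        # Slice one sequence out of the packed tensor and split it into Q, K, V.
        q, k, v = (split_heads(p) for p in np.split(qkv[start:end], 3, axis=-1))
        # First batched matmul (one matmul per head), scaled as in the docstring.
        scores = (q @ k.transpose(0, 2, 1)) * scale  # [num_heads, len, len]
        # Numerically stable softmax over the key axis.
        scores -= scores.max(axis=-1, keepdims=True)
        probs = np.exp(scores)
        probs /= probs.sum(axis=-1, keepdims=True)
        # Second batched matmul, then merge heads back to [len, hidden_dim].
        ctx = probs @ v  # [num_heads, len, head_size]
        out[start:end] = ctx.transpose(1, 0, 2).reshape(length, hidden_dim)
        start = end
    return out


# Usage: two sequences of lengths 3 and 5 packed into 8 rows.
rng = np.random.default_rng(0)
lengths = [3, 5]
num_heads, head_size = 2, 4
qkv = rng.standard_normal((sum(lengths), 3 * num_heads * head_size))
out = reference_bert_attention(qkv, lengths, num_heads, head_size, q_scaling=1.0)
print(out.shape)  # (8, 8)
```

Because each sequence is sliced out using the running offsets of the lengths, tokens never attend across sequence boundaries and no padding or mask is needed, which is consistent with the input being described in terms of 'sum_of_tokens' rather than a padded [batch_size, max_len] shape.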
def bert_attention(tensor: Tensor,
                   input_lengths: Tensor,
                   num_heads: int,
                   head_size: int,
                   q_scaling: float,
                   relative_attention: bool = False,
                   relative_attention_bias: Tensor = None,
                   max_distance: int = 0,
                   max_input_length: Tensor = None,
                   sage_attn: bool = False,
                   sage_attn_q_block_size: int = 0,
                   sage_attn_k_block_size: int = 0,
                   sage_attn_v_block_size: int = 0,
                   cp_group: list[int] = None,
                   cp_size: int = 1,
                   cp_rank: int = 0) -> Tuple[Tensor]:
    '''
    Add an operation that performs the multi-head attention in BERT.

    The multi-head attention (MHA) is the sequence of a batched matmul, a
    softmax and a batched matmul as described in
    https://arxiv.org/abs/1706.03762. This function adds an operation that
    performs those computations using a single GPU kernel.

    The input tensor contains the Q, K and V elements. It is a 2D tensor and
    its shape is '[sum_of_tokens, 3*hidden_dim]' where 'sum_of_tokens' is
    the sum of the sequence lengths in the batch.

    In MHA, the output of the Q*K^T product is scaled by a constant value that
    is computed as:

        1.f / (q_scaling * sqrt(head_size)).

    The 'q_scaling' constant is the fifth argument of this function.

    This layer is implemented using a plugin (see bertAttentionPlugin).

    Parameters:
        tensor : Tensor
            The QKV input tensor.

        input_lengths : Tensor
            The length of each sequence. It is a 1D tensor of size 'batch_size'.

        num_heads : int
            The number of heads.

        head_size : int
            The size of each head.

        q_scaling : float
            The factor used to compute the scaling applied to the output of
            the 'Q*K^T' product.

        relative_attention: bool = False
            Whether to enable relative attention.

        relative_attention_bias: Tensor = None