Expand an attention mask. This function adds the sequence of operations to expand a tensor of shape '[batch_size, src_seq_len]' to a tensor of shape '[batch_size, 1, tgt_seq_len, src_seq_len]'. It can be used to create the mask applied to the Q*K^T product before the softmax operation in the multi-head attention block.
(mask: Tensor, tgt_len: Optional[Tensor] = None)
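The shape semantics can be checked without building a graph. Below is a minimal pure-Python sketch that mimics the same transformation on nested lists. The name `expand_mask_reference` is hypothetical and is not part of any library. The sketch assumes a rectangular 0/1 mask in which 0 marks a position to be masked out, matching the `mask == 0` test in the function body.

```python
import math
from typing import List, Optional


def expand_mask_reference(
    mask: List[List[int]], tgt_len: Optional[int] = None
) -> List[List[List[List[float]]]]:
    """Hypothetical pure-Python analogue of expand_mask (illustration only).

    Maps a [batch_size, src_seq_len] 0/1 mask to a
    [batch_size, 1, tgt_seq_len, src_seq_len] additive mask.
    """
    bsz = len(mask)
    src_len = len(mask[0]) if bsz > 0 else 0
    # Same default as expand_mask: reuse the source length for the target axis.
    tgt_len = tgt_len if tgt_len is not None else src_len

    out = []
    for b in range(bsz):
        # Zero entries become -inf (blocked); non-zero entries become 0.0 (kept).
        row = [-math.inf if mask[b][s] == 0 else 0.0 for s in range(src_len)]
        # Repeat the row tgt_len times under a singleton second axis.
        out.append([[list(row) for _ in range(tgt_len)]])
    return out


# Usage: batch of 2, source length 3; the last token of sample 0 is padding.
expanded = expand_mask_reference([[1, 1, 0], [1, 1, 1]], tgt_len=2)
print(len(expanded), len(expanded[0]), len(expanded[0][0]), len(expanded[0][0][0]))  # 2 1 2 3
print(expanded[0][0][1])  # [0.0, 0.0, -inf]
```

Every target row of a given sample is identical. The expansion only copies the per-sample mask along the target axis. It does not introduce any causal structure.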
```python
def expand_mask(mask: Tensor, tgt_len: Optional[Tensor] = None) -> Tensor:
    '''
    Expand an attention mask.

    This function adds the sequence of operations to expand a tensor of
    shape '[batch_size, src_seq_len]' to a tensor of shape
    '[batch_size, 1, tgt_seq_len, src_seq_len]'. It can be used to create the
    mask applied to the Q*K^T product before the softmax operation in the
    multi-head attention block.

    Parameters:
        mask : Tensor
            The input mask of shape '[batch_size, src_seq_len]'. Positions
            equal to 0 are masked out.

        tgt_len : Optional[Tensor]
            The size of the 3rd dimension in the output tensor. If None,
            the 2nd dimension of the input is used.

    Returns:
        An additive mask of shape '[batch_size, 1, tgt_seq_len, src_seq_len]'
        that holds 0.0 where the input mask is non-zero and -inf where it is
        zero.
    '''
    # Read the batch size and source length from the mask's shape.
    bsz = shape(mask, 0)
    src_len = shape(mask, 1)
    tgt_len = tgt_len if tgt_len is not None else src_len

    # Insert two singleton axes: [bsz, src_len] -> [bsz, 1, 1, src_len].
    mask = mask.view(concat([bsz, 1, 1, src_len]))

    # Broadcast along the target axis: -> [bsz, 1, tgt_len, src_len].
    mask = expand(mask, concat([bsz, 1, tgt_len, src_len]))
    # Convert to an additive mask: 0 -> -inf (blocked), non-zero -> 0.0 (kept).
    mask = where(mask == 0, float('-inf'), 0.0)
    return mask
```
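`expand_mask` outputs `0.0` and `-inf` rather than a boolean mask because the result is added to the Q*K^T scores before softmax. Since exp(-inf) = 0, blocked positions receive exactly zero attention weight. The minimal pure-Python sketch below shows this on a single attention row. The helper `masked_softmax` and the score values are hypothetical and exist only for illustration. In the attention block itself, the same addition acts on whole score tensors, with the mask's singleton second axis typically broadcast across attention heads.

```python
import math
from typing import List


def masked_softmax(scores: List[float], additive_mask: List[float]) -> List[float]:
    """Hypothetical helper: softmax over one row of Q*K^T scores plus an additive mask."""
    # Adding -inf blocks a position; adding 0.0 leaves its score unchanged.
    logits = [s + m for s, m in zip(scores, additive_mask)]
    # Subtract the row maximum for numerical stability (assumes at least one
    # position is unmasked; otherwise every entry would be -inf).
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 3.0]          # hypothetical scaled Q*K^T row
mask_row = [0.0, 0.0, -math.inf]  # last source position is padding
weights = masked_softmax(scores, mask_row)
print(weights[2])  # 0.0
```

The padded position would have received the largest weight (score 3.0), but after masking, its weight is zero. The remaining weight is split between the two unmasked positions.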
| 6499 | def gather_last_token_logits(hidden_states: Tensor, last_token_ids: Tensor, |