(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
fa_version=None,
)
| 131 | |
| 132 | |
def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    """Compute multi-head attention, preferring FlashAttention when available.

    When ``FLASH_ATTN_2_AVAILABLE`` or ``FLASH_ATTN_3_AVAILABLE`` is truthy, every
    argument is forwarded to ``flash_attention`` (``fa_version`` is passed as its
    ``version``). Otherwise the computation falls back to
    ``torch.nn.functional.scaled_dot_product_attention`` (SDPA).

    Args:
        q, k, v: Query/key/value tensors. The fallback swaps dims 1 and 2 before
            calling SDPA (which expects ``[..., seq_len, head_dim]``), so inputs are
            presumably ``[batch, seq_len, num_heads, head_dim]`` -- TODO confirm
            against ``flash_attention``'s contract.
        q_lens: Optional per-sample valid query lengths. The fallback does not
            need them: every query row is computed independently, so padded query
            rows cannot change the valid rows. Outputs at padded query positions
            are not meaningful and should be discarded by the caller.
        k_lens: Optional valid key length per batch element (assumed shape
            ``[batch]`` -- confirm). In the fallback, keys at positions
            ``>= k_lens[b]`` are masked out of sample ``b``'s attention.
        dropout_p: Attention dropout probability.
        softmax_scale: Scale applied to ``q @ k^T`` before softmax; ``None`` keeps
            the backend default.
        q_scale: Forwarded to ``flash_attention``; the fallback ignores it and
            emits a warning.
        causal: Apply a causal mask. In the fallback this is SDPA's top-left
            aligned lower-triangular mask.
        window_size: Forwarded to ``flash_attention``; in the fallback any value
            other than ``(-1, -1)`` is ignored and a warning is emitted.
        deterministic: Forwarded to ``flash_attention``; unused by the fallback.
        dtype: Dtype q/k/v are cast to before SDPA in the fallback.
        fa_version: Forwarded to ``flash_attention`` as ``version``.

    Returns:
        The attention output. In the fallback it is transposed back to the input
        layout (dims 1 and 2 swapped back), is contiguous, and has dtype ``dtype``.
    """
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )

    # SDPA has no equivalent for these options. Warn rather than silently
    # computing a different attention than the flash path would.
    ignored = []
    if q_scale is not None:
        ignored.append('q_scale')
    if window_size is not None and tuple(window_size) != (-1, -1):
        ignored.append('window_size')
    if ignored:
        warnings.warn(
            'Ignoring argument(s) unsupported by scaled_dot_product_attention: '
            + ', '.join(ignored)
        )

    q = q.transpose(1, 2).to(dtype)
    k = k.transpose(1, 2).to(dtype)
    v = v.transpose(1, 2).to(dtype)

    attn_mask = None
    is_causal = causal
    if k_lens is not None:
        lk = k.size(-2)
        valid_k = torch.as_tensor(k_lens, device=k.device).reshape(-1, 1)
        # Boolean mask, True = may attend. [B, Lk] -> [B, 1, 1, Lk] so that it
        # broadcasts over heads and query positions.
        attn_mask = (torch.arange(lk, device=k.device) < valid_k)[:, None, None, :]
        # NOTE(review): a k_lens entry of 0 leaves rows with no visible key;
        # depending on the PyTorch version/backend SDPA may return NaN there.
        if causal:
            # SDPA raises if attn_mask is combined with is_causal=True, so fold
            # the causal constraint into the mask. tril() over the (Lq, Lk) grid
            # reproduces SDPA's own top-left-aligned causal mask.
            lq = q.size(-2)
            causal_mask = torch.ones(lq, lk, dtype=torch.bool, device=q.device).tril()
            attn_mask = attn_mask & causal_mask
            is_causal = False

    sdpa_kwargs = {}
    if softmax_scale is not None:
        # `scale` was added to SDPA in PyTorch 2.1. Pass it only when requested
        # so the default call keeps working on older releases.
        sdpa_kwargs['scale'] = softmax_scale

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, is_causal=is_causal, dropout_p=dropout_p,
        **sdpa_kwargs)

    out = out.transpose(1, 2).contiguous()
    return out
nothing calls this directly
no test coverage detected