q: [B, Lq, Nq, C1]. k: [B, Lk, Nk, C1]. v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. q_lens: [B]. k_lens: [B]. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax.
(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
version=None,
)
| 22 | |
| 23 | |
| 24 | def flash_attention( |
| 25 | q, |
| 26 | k, |
| 27 | v, |
| 28 | q_lens=None, |
| 29 | k_lens=None, |
| 30 | dropout_p=0., |
| 31 | softmax_scale=None, |
| 32 | q_scale=None, |
| 33 | causal=False, |
| 34 | window_size=(-1, -1), |
| 35 | deterministic=False, |
| 36 | dtype=torch.bfloat16, |
| 37 | version=None, |
| 38 | ): |
| 39 | """ |
| 40 | q: [B, Lq, Nq, C1]. |
| 41 | k: [B, Lk, Nk, C1]. |
| 42 | v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. |
| 43 | q_lens: [B]. |
| 44 | k_lens: [B]. |
| 45 | dropout_p: float. Dropout probability. |
| 46 | softmax_scale: float. The scaling of QK^T before applying softmax. |
| 47 | causal: bool. Whether to apply causal attention mask. |
| 48 | window_size: (left right). If not (-1, -1), apply sliding window local attention. |
| 49 | deterministic: bool. If True, slightly slower and uses more memory. |
| 50 | dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. |
| 51 | """ |
| 52 | half_dtypes = (torch.float16, torch.bfloat16) |
| 53 | assert dtype in half_dtypes |
| 54 | assert q.device.type == 'cuda' and q.size(-1) <= 256 |
| 55 | |
| 56 | # params |
| 57 | b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype |
| 58 | |
| 59 | def half(x): |
| 60 | return x if x.dtype in half_dtypes else x.to(dtype) |
| 61 | |
| 62 | # preprocess query |
| 63 | if q_lens is None: |
| 64 | q = half(q.flatten(0, 1)) |
| 65 | q_lens = torch.tensor( |
| 66 | [lq] * b, dtype=torch.int32).to( |
| 67 | device=q.device, non_blocking=True) |
| 68 | else: |
| 69 | q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) |
| 70 | |
| 71 | # preprocess key, value |
| 72 | if k_lens is None: |
| 73 | k = half(k.flatten(0, 1)) |
| 74 | v = half(v.flatten(0, 1)) |
| 75 | k_lens = torch.tensor( |
| 76 | [lk] * b, dtype=torch.int32).to( |
| 77 | device=k.device, non_blocking=True) |
| 78 | else: |
| 79 | k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) |
| 80 | v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) |
| 81 |