Repository: github.com/Wan-Video/Wan2.1

Function flash_attention

wan/modules/attention.py:24–130

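q: [B, Lq, Nq, C1]. k: [B, Lk, Nk, C1]. v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. q_lens: [B]. k_lens: [B]. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. q must be a CUDA tensor with C1 ≤ 256, and any of q/k/v that is not float16/bfloat16 is cast to dtype (default torch.bfloat16). When q_lens/k_lens are given, only the first q_lens[b]/k_lens[b] tokens of each sequence b are kept; otherwise every sequence is used at its full length.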
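To make the shape contract concrete, here is a naive pure-Python reference of the computation the summary describes, softmax(softmax_scale · QK^T)·V per head, for a single batch element. `reference_attention` is a hypothetical helper, not part of Wan2.1. Two details are assumptions taken from flash-attn's documented behaviour rather than from this excerpt: the default scale of 1/√C1, and the grouped-query mapping in which query head h reads key/value head h // (Nq / Nk), which is why Nq must be divisible by Nk.

```python
import math


def reference_attention(q, k, v, softmax_scale=None):
    """Naive attention for ONE batch element (hypothetical sketch, not Wan2.1 code).

    q: [Lq][Nq][C1], k: [Lk][Nk][C1], v: [Lk][Nk][C2] as nested lists.
    Returns a nested list shaped [Lq][Nq][C2].
    """
    lq, nq, c1 = len(q), len(q[0]), len(q[0][0])
    lk, nk = len(k), len(k[0])
    c2 = len(v[0][0])
    # Same requirement as the docstring: Nq must be divisible by Nk.
    assert nq % nk == 0, "Nq must be divisible by Nk"
    group = nq // nk  # number of query heads sharing one key/value head
    # Assumed default (flash-attn convention): 1 / sqrt(head_dim).
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(c1)

    out = [[[0.0] * c2 for _ in range(nq)] for _ in range(lq)]
    for h in range(nq):
        kv_h = h // group  # grouped-query mapping: query head h -> kv head h // group
        for i in range(lq):
            # Scaled dot product of query i (head h) with every key.
            scores = [
                scale * sum(a * b for a, b in zip(q[i][h], k[j][kv_h]))
                for j in range(lk)
            ]
            # Numerically stable softmax over the keys.
            m = max(scores)
            weights = [math.exp(s - m) for s in scores]
            total = sum(weights)
            # Weighted sum of the value vectors.
            for j in range(lk):
                w = weights[j] / total
                for c in range(c2):
                    out[i][h][c] += w * v[j][kv_h][c]
    return out
```

This loop costs O(Lq·Lk) per head and is only meant for checking tiny cases by hand; the point of a flash-attention kernel is to produce the same result without materializing the full Lq×Lk score matrix.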

(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
)
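The docstring says only that a `window_size` other than (-1, -1) applies sliding-window local attention. The sketch below spells out which keys each query may see, assuming the convention documented for flash-attn: query i attends to keys j with i + (Lk − Lq) − left ≤ j ≤ i + (Lk − Lq) + right, a negative bound means "unlimited" on that side, and `causal=True` additionally caps j at i + (Lk − Lq), i.e. masks are aligned to the bottom-right corner. `allowed_keys` is a hypothetical helper, not part of Wan2.1.

```python
def allowed_keys(lq, lk, causal=False, window_size=(-1, -1)):
    """For each query row, list the key indices it may attend to (hypothetical sketch).

    Assumes flash-attn's convention: window_size = (left, right), negative means
    unbounded on that side, and causal/window masks align to the bottom-right corner.
    """
    left, right = window_size
    offset = lk - lq  # 0 when query and key lengths match
    rows = []
    for i in range(lq):
        lo = 0 if left < 0 else i + offset - left
        hi = lk - 1 if right < 0 else i + offset + right
        if causal:
            # Causal masking forbids keys "after" the (shifted) query position.
            hi = min(hi, i + offset)
        rows.append([j for j in range(lk) if lo <= j <= hi])
    return rows
```

With the defaults (causal=False, window_size=(-1, -1)) every query sees every key, which matches the docstring's statement that windowing only applies when the value is not (-1, -1).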


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    q: [B, Lq, Nq, C1].
    k: [B, Lk, Nk, C1].
    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens: [B].
    k_lens: [B].
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
    causal: bool. Whether to apply causal attention mask.
    window_size: (left, right). If not (-1, -1), apply sliding window local attention.
    deterministic: bool. If True, slightly slower and uses more memory.
    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [lq] * b, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [lk] * b, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    # ... (remainder of the function, through line 130, not shown in this excerpt)
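The preprocessing above turns a padded [B, L, ...] batch into one packed token sequence: with no lengths it simply flattens the batch and sequence axes (`flatten(0, 1)`), and with lengths it keeps only the first `lens[b]` tokens of each sequence before concatenating. The pure-Python sketch below mirrors that step on nested lists. It also derives cumulative offsets, which is an assumption about what happens in the part of the function not shown here: variable-length kernels such as flash-attn's `flash_attn_varlen_func` take `cu_seqlens_q`/`cu_seqlens_k` offsets rather than per-sequence lengths. `pack` is a hypothetical helper, not part of Wan2.1.

```python
def pack(batch, lens=None):
    """Pack a [B][L][...] batch into one flat token list (hypothetical sketch).

    lens is None  -> every sequence keeps its full length (like flatten(0, 1)).
    lens is given -> sequence b keeps only its first lens[b] tokens (like u[:v]).
    Returns (packed_tokens, lens, cu_seqlens).
    """
    if lens is None:
        lens = [len(seq) for seq in batch]
    # Drop padding past each sequence's length, then concatenate in batch order.
    packed = [tok for seq, n in zip(batch, lens) for tok in seq[:n]]
    # Cumulative offsets: sequence b occupies packed[cu_seqlens[b]:cu_seqlens[b + 1]].
    cu_seqlens = [0]
    for n in lens:
        cu_seqlens.append(cu_seqlens[-1] + n)
    return packed, lens, cu_seqlens


batch = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]
print(pack(batch))          # full length, like the q_lens is None branch
print(pack(batch, [2, 1]))  # truncated, like the q_lens branch
```

The upstream comprehension `[u[:v] for u, v in zip(v, k_lens)]` reuses `v` as both the value tensor and the loop variable. It still works in Python 3 because the outermost iterable of a comprehension (`zip(v, k_lens)`) is evaluated in the enclosing scope before the comprehension's own `v` is bound; a distinct name such as `n`, as in the sketch, reads more clearly.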

Callers (6)

forward · method · 0.85
forward · method · 0.85
attention · function · 0.85
forward · method · 0.85
forward · method · 0.85
forward · method · 0.85

Calls (1)

half · function · 0.70

Tested by

no test coverage detected