Repository: github.com/Robbyant/lingbot-world

Function flash_attention

Defined in wan/modules/attention.py, lines 23–129.

Shapes: q is [B, Lq, Nq, C1], k is [B, Lk, Nk, C1], and v is [B, Lk, Nk, C2]; Nq must be divisible by Nk. q_lens and k_lens are [B] tensors holding the valid sequence length of each sample. dropout_p (float) is the dropout probability, and softmax_scale (float) is the scaling applied to QK^T before softmax.
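The shape contract can be summarized as a small validation sketch. `check_shapes` and `AttnShapes` are hypothetical helpers (not part of lingbot-world) that work on plain shape tuples instead of tensors so the rules are easy to follow; the 256 head-dim limit mirrors the assertion in the source. When Nq is a multiple of Nk, each key/value head is shared by Nq // Nk query heads, as in grouped-query attention.

```python
from typing import NamedTuple, Sequence


class AttnShapes(NamedTuple):
    batch: int
    q_len: int
    k_len: int
    q_heads: int
    kv_heads: int
    head_dim_qk: int
    head_dim_v: int
    group_size: int  # query heads served by each key/value head


def check_shapes(q: Sequence[int], k: Sequence[int],
                 v: Sequence[int]) -> AttnShapes:
    """Validate [B, L, N, C] shape tuples against the docstring contract."""
    if not (len(q) == len(k) == len(v) == 4):
        raise ValueError("q, k and v must all be 4-D: [B, L, N, C]")
    b, lq, nq, c1 = q
    bk, lk, nk, c1_k = k
    bv, lk_v, nk_v, c2 = v
    if not (b == bk == bv):
        raise ValueError("q, k and v must share the batch size B")
    if (lk, nk) != (lk_v, nk_v):
        raise ValueError("k and v must share Lk and Nk")
    if c1 != c1_k:
        raise ValueError("q and k must share the head dim C1")
    if nq % nk != 0:
        raise ValueError(f"Nq={nq} must be divisible by Nk={nk}")
    if c1 > 256:
        # Mirrors the source's `assert q.size(-1) <= 256`.
        raise ValueError(f"head dim C1={c1} exceeds 256")
    return AttnShapes(b, lq, lk, nq, nk, c1, c2, nq // nk)
```

For example, q of shape (2, 16, 8, 64) with k of shape (2, 32, 2, 64) and v of shape (2, 32, 2, 128) passes, with each key/value head serving 4 query heads.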

(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
)
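The `causal` and `window_size` parameters in this signature share their names with the flash-attn library, whose documentation describes query i attending only to keys in [i + Lk - Lq - left, i + Lk - Lq + right], with -1 meaning "unbounded on that side" and `causal=True` fixing the right edge at the (bottom-right aligned) diagonal. Assuming the wrapper forwards these arguments unchanged (the kernel call lies outside the excerpt shown on this page), the hypothetical helper `visible_keys` enumerates which key positions each query may attend to.

```python
from typing import List, Tuple


def visible_keys(lq: int, lk: int, causal: bool = False,
                 window_size: Tuple[int, int] = (-1, -1)) -> List[List[int]]:
    """Return, for each query position, the key positions it may attend to.

    Sketch of flash-attn-style masking: -1 in window_size means unbounded
    on that side, and causal=True caps the right edge at the diagonal.
    """
    left, right = window_size
    if causal:
        right = 0  # causal masking behaves like a right window of 0
    offset = lk - lq  # aligns the diagonal to the bottom-right corner
    rows = []
    for i in range(lq):
        centre = i + offset
        lo = 0 if left < 0 else max(0, centre - left)
        hi = lk - 1 if right < 0 else min(lk - 1, centre + right)
        rows.append(list(range(lo, hi + 1)))  # empty if hi < lo
    return rows
```

With the default `window_size=(-1, -1)` and `causal=False`, every query sees all Lk keys, which is ordinary full attention.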


import torch  # module-level import from the top of attention.py


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    q: [B, Lq, Nq, C1].
    k: [B, Lk, Nk, C1].
    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens: [B].
    k_lens: [B].
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
    causal: bool. Whether to apply causal attention mask.
    window_size: (left, right). If not (-1, -1), apply sliding window local attention.
    deterministic: bool. If True, slightly slower and uses more memory.
    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [lq] * b, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [lk] * b, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    # ... (lines 81–129 of attention.py, the rest of this function, are not shown in this excerpt)
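The two preprocessing branches either flatten the batch and sequence dimensions (every sample keeps its full padded length) or concatenate only the first q_lens[i] / k_lens[i] tokens of each sample. Varlen kernels such as flash-attn's `flash_attn_varlen_func` locate each sample inside such a packed tensor through cumulative offsets (`cu_seqlens`). The excerpt ends before that step, so the offset computation below is an assumption about the remainder of the function; `pack_varlen` is a hypothetical helper that works on Python lists instead of tensors.

```python
from itertools import accumulate
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def pack_varlen(batch: Sequence[Sequence[T]],
                lens: Optional[Sequence[int]] = None) -> Tuple[List[T], List[int]]:
    """Pack padded sequences into one flat token list plus cumulative offsets.

    Mirrors the excerpt's two branches: with lens=None every sample keeps its
    full (padded) length, like q.flatten(0, 1); otherwise only the first
    lens[i] tokens of sample i are kept, like torch.cat([u[:n] ...]).
    """
    if lens is None:
        lens = [len(seq) for seq in batch]
    packed: List[T] = []
    for seq, n in zip(batch, lens):
        packed.extend(seq[:n])
    # cu_seqlens[i] is where sample i starts in `packed`; the last entry is
    # the total number of packed tokens.
    cu_seqlens = [0, *accumulate(lens)]
    return packed, cu_seqlens
```

For example, lengths [2, 3] give offsets [0, 2, 5]: sample 0 occupies packed[0:2] and sample 1 occupies packed[2:5], and the padding token of sample 0 is dropped.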

Callers 14

attention (function) · 0.85
forward (method) · 0.85
forward (method) · 0.85
forward (method) · 0.85
forward (method) · 0.85
forward (method) · 0.85
forward (method) · 0.85
forward (method) · 0.85
forward (method) · 0.85
forward (method) · 0.85
forward (method) · 0.85
forward (method) · 0.85

Calls 4

size (method) · 0.80
to (method) · 0.80
type (method) · 0.80
half (function) · 0.70

Tested by

no test coverage detected