MCPcopy
hub / github.com/MeiGen-AI/InfiniteTalk / flash_attention

Function flash_attention

wan/modules/attention.py:33–139  ·  view source on GitHub ↗

q: [B, Lq, Nq, C1]. k: [B, Lk, Nk, C1]. v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. q_lens: [B]. k_lens: [B]. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax.

(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
)

Source from the content-addressed store, hash-verified

31
32
33def flash_attention(
34 q,
35 k,
36 v,
37 q_lens=None,
38 k_lens=None,
39 dropout_p=0.,
40 softmax_scale=None,
41 q_scale=None,
42 causal=False,
43 window_size=(-1, -1),
44 deterministic=False,
45 dtype=torch.bfloat16,
46 version=None,
47):
48 """
49 q: [B, Lq, Nq, C1].
50 k: [B, Lk, Nk, C1].
51 v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
52 q_lens: [B].
53 k_lens: [B].
54 dropout_p: float. Dropout probability.
55 softmax_scale: float. The scaling of QK^T before applying softmax.
56 causal: bool. Whether to apply causal attention mask.
57 window_size: (left right). If not (-1, -1), apply sliding window local attention.
58 deterministic: bool. If True, slightly slower and uses more memory.
59 dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
60 """
61 half_dtypes = (torch.float16, torch.bfloat16)
62 assert dtype in half_dtypes
63 assert q.device.type == 'cuda' and q.size(-1) <= 256
64
65 # params
66 b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
67
68 def half(x):
69 return x if x.dtype in half_dtypes else x.to(dtype)
70
71 # preprocess query
72 if q_lens is None:
73 q = half(q.flatten(0, 1))
74 q_lens = torch.tensor(
75 [lq] * b, dtype=torch.int32).to(
76 device=q.device, non_blocking=True)
77 else:
78 q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
79
80 # preprocess key, value
81 if k_lens is None:
82 k = half(k.flatten(0, 1))
83 v = half(v.flatten(0, 1))
84 k_lens = torch.tensor(
85 [lk] * b, dtype=torch.int32).to(
86 device=k.device, non_blocking=True)
87 else:
88 k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
89 v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
90

Callers 8

forward · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85
attention · Function · 0.85
forward · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85

Calls 1

half · Function · 0.70

Tested by

no test coverage detected