hub / github.com/Wan-Video/Wan2.2 / attention

Function attention

wan/modules/attention.py:133–179
(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
)

Source from the content-addressed store, hash-verified

def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )
    else:
        if q_lens is not None or k_lens is not None:
            warnings.warn(
                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
            )
        attn_mask = None

        q = q.transpose(1, 2).to(dtype)
        k = k.transpose(1, 2).to(dtype)
        v = v.transpose(1, 2).to(dtype)

        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)

        out = out.transpose(1, 2).contiguous()
        return out
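
In the fallback branch above, `attn_mask` is always `None`, so `q_lens` and `k_lens` are ignored apart from the warning. `softmax_scale`, `q_scale`, `window_size` and `deterministic` are likewise not forwarded to `scaled_dot_product_attention`. The sketch below shows one way a key padding mask could be derived from `k_lens` and passed to the same PyTorch call. The helper name `sdpa_with_key_padding` is hypothetical and not part of Wan2.2. It assumes inputs laid out as `[B, L, N, D]` (which the `transpose(1, 2)` calls suggest) and covers only the non-causal case, since PyTorch does not accept an explicit `attn_mask` together with `is_causal=True`.

```python
import torch
import torch.nn.functional as F


def sdpa_with_key_padding(q, k, v, k_lens=None, dtype=torch.float32):
    """Hypothetical variant of the fallback branch that honors k_lens.

    q: [B, Lq, N, D]; k, v: [B, Lk, N, D]; k_lens: [B] valid key lengths.
    Non-causal only. Not part of the Wan2.2 code base.
    """
    attn_mask = None
    if k_lens is not None:
        lk = k.size(1)
        positions = torch.arange(lk, device=k.device)
        # True marks key positions that may be attended to: [B, Lk]
        valid = positions[None, :] < k_lens.to(k.device)[:, None]
        # Broadcast over heads and query positions: [B, 1, 1, Lk]
        attn_mask = valid[:, None, None, :]

    # scaled_dot_product_attention expects [B, N, L, D],
    # so swap the sequence and head axes as the original does.
    q = q.transpose(1, 2).to(dtype)
    k = k.transpose(1, 2).to(dtype)
    v = v.transpose(1, 2).to(dtype)

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

    # Back to [B, Lq, N, D], matching the original return layout.
    return out.transpose(1, 2).contiguous()


if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(2, 3, 2, 4)
    k = torch.randn(2, 5, 2, 4)
    v = torch.randn(2, 5, 2, 4)
    out = sdpa_with_key_padding(q, k, v, k_lens=torch.tensor([5, 2]))
    print(out.shape)
```

The demo uses `float32` so that it runs on CPU; the original function defaults to `torch.bfloat16`. Padded query rows are not masked here; their outputs would simply be discarded by the caller.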

Callers

nothing calls this directly

Calls 2

flash_attention · Function · 0.85
to · Method · 0.80
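
Whether `attention` reaches `flash_attention` at all depends on the module-level flags `FLASH_ATTN_2_AVAILABLE` and `FLASH_ATTN_3_AVAILABLE`, whose definitions are not shown on this page. The following is a minimal sketch of how such flags might be set and how the dispatch decision follows from them. The module names `flash_attn` and `flash_attn_interface` are assumptions, and `module_available` and `pick_backend` are hypothetical helpers used only for illustration.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if the top-level module `name` can be located (no import is performed)."""
    return importlib.util.find_spec(name) is not None


# Assumed module names; the real flags are defined elsewhere in attention.py.
FLASH_ATTN_2_AVAILABLE = module_available("flash_attn")
FLASH_ATTN_3_AVAILABLE = module_available("flash_attn_interface")


def pick_backend() -> str:
    """Mirror the branch condition in `attention`: prefer FlashAttention when present."""
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return "flash_attention"
    return "scaled_dot_product_attention"


if __name__ == "__main__":
    print(pick_backend())
```

On a machine without either package installed, `pick_backend()` selects the `scaled_dot_product_attention` path, which is the branch that drops the padding mask.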

Tested by

no test coverage detected