MCPcopy
hub / github.com/Robbyant/lingbot-world / attention

Function attention

wan/modules/attention.py:132–178  ·  view source on GitHub ↗
(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
)

Source from the content-addressed store, hash-verified

130
131
def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    """Run attention via FlashAttention when available, else via PyTorch SDPA.

    When either ``FLASH_ATTN_2_AVAILABLE`` or ``FLASH_ATTN_3_AVAILABLE`` is
    truthy, every argument is forwarded unchanged to ``flash_attention``
    (``fa_version`` is passed as its ``version`` keyword).

    Otherwise the call falls back to
    ``torch.nn.functional.scaled_dot_product_attention``. In that fallback:

    * ``q``, ``k`` and ``v`` are cast to ``dtype`` and have dims 1 and 2
      swapped before the call; the output is swapped back and made
      contiguous. (Presumably inputs are ``[batch, seq, heads, head_dim]``
      -- confirm against the caller.)
    * ``q_lens`` / ``k_lens`` are NOT applied as a padding mask; a warning is
      emitted if either is given.
    * ``softmax_scale`` and ``q_scale`` are honoured when not ``None``.
    * ``window_size``, ``deterministic`` and ``fa_version`` are not used.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        q_lens: Optional per-sample query lengths (ignored by the fallback).
        k_lens: Optional per-sample key lengths (ignored by the fallback).
        dropout_p: Attention dropout probability.
        softmax_scale: Optional scaling applied to the attention logits;
            ``None`` keeps the backend default.
        q_scale: Optional multiplier applied to ``q`` before attention.
        causal: Whether to apply a causal mask.
        window_size: Forwarded to ``flash_attention`` only.
        deterministic: Forwarded to ``flash_attention`` only.
        dtype: Compute dtype the inputs are cast to in the fallback.
        fa_version: Forwarded to ``flash_attention`` as ``version``.

    Returns:
        The attention output tensor, in the same dim order as ``q``.
    """
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )

    if q_lens is not None or k_lens is not None:
        warnings.warn(
            'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
        )

    q = q.to(dtype)
    if q_scale is not None:
        # Previously dropped silently on this path. Applied in the caller's
        # layout (before the transpose) and re-cast so q keeps the same dtype
        # as k/v even if q_scale is a higher-precision tensor.
        # NOTE(review): assumes q_scale is a multiplicative factor on q, as
        # its name suggests -- confirm it matches flash_attention's handling.
        q = (q * q_scale).to(dtype)

    q = q.transpose(1, 2)
    k = k.transpose(1, 2).to(dtype)
    v = v.transpose(1, 2).to(dtype)

    # Only pass `scale` when the caller asked for one, so the default call is
    # identical to before (and still works on PyTorch builds whose SDPA
    # predates the `scale` keyword).
    sdpa_kwargs = {}
    if softmax_scale is not None:
        sdpa_kwargs['scale'] = softmax_scale

    out = torch.nn.functional.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,
        is_causal=causal,
        dropout_p=dropout_p,
        **sdpa_kwargs,
    )

    return out.transpose(1, 2).contiguous()

Callers 1

forwardMethod · 0.90

Calls 2

flash_attentionFunction · 0.85
toMethod · 0.80

Tested by

no test coverage detected