hub / github.com/hpcaitech/ColossalAI / attention_ref

Function attention_ref

tests/test_shardformer/test_flash_attention.py:22–30
(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask=None, dropout_p=0.0)

Source from the content-addressed store, hash-verified

def attention_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask=None, dropout_p=0.0):
    head_dim = q.size(-1)
    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float).to(q.dtype)
    attn_weights = torch.dropout(attn_weights, p=dropout_p, train=True)
    attn_output = torch.matmul(attn_weights, v)
    return attn_output
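The function above computes scaled dot-product attention: softmax(q·kᵀ / sqrt(head_dim) + attn_mask)·v, applied per batch and head. As a rough illustration of that arithmetic, here is a minimal, dependency-free sketch for a single head on nested lists. The names `softmax` and `attention_single_head` are hypothetical and not part of ColossalAI; the sketch omits dropout and the float32 softmax cast, so it corresponds to `dropout_p=0.0` only.

```python
import math


def softmax(row):
    # Subtract the row maximum before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_single_head(q, k, v, attn_mask=None):
    """Scaled dot-product attention for one head (hypothetical helper).

    q: [seq_q][head_dim], k: [seq_k][head_dim], v: [seq_k][value_dim].
    attn_mask, if given, is an additive [seq_q][seq_k] mask
    (0.0 keeps a position, -inf removes it).
    """
    head_dim = len(q[0])
    scale = 1.0 / math.sqrt(head_dim)
    out = []
    for i, q_row in enumerate(q):
        # One score per key: dot(q_i, k_j) / sqrt(head_dim).
        scores = [sum(a * b for a, b in zip(q_row, k_row)) * scale for k_row in k]
        if attn_mask is not None:
            scores = [s + m for s, m in zip(scores, attn_mask[i])]
        weights = softmax(scores)
        # Weighted sum of the value rows.
        out.append([
            sum(w * v_row[j] for w, v_row in zip(weights, v))
            for j in range(len(v[0]))
        ])
    return out


# Two query/key positions with a causal (lower-triangular) additive mask.
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
causal_mask = [[0.0, -math.inf], [0.0, 0.0]]

out = attention_single_head(q, k, v, attn_mask=causal_mask)
print(out[0])  # the first query can only see the first key, so this is v[0]
```

Because the mask is additive, a `-inf` entry drives the corresponding softmax weight to zero, which is why the first output row equals the first value row. Every row needs at least one unmasked key; a fully masked row would produce NaN in this sketch.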

Callers 1

check_attn_func · Function · 0.85

Calls 2

size · Method · 0.45
to · Method · 0.45

Tested by

no test coverage detected
