(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask=None, dropout_p=0.0)
| 20 | |
| 21 | |
def attention_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask=None, dropout_p=0.0) -> torch.Tensor:
    """Compute scaled dot-product attention using plain tensor operations.

    The attention scores are ``q @ k^T / sqrt(head_dim)``, where ``head_dim``
    is the size of the last dimension of ``q`` and ``k^T`` swaps the last two
    dimensions of ``k``. The scores are optionally shifted by ``attn_mask``,
    normalised with a softmax over the last dimension, passed through
    dropout, and used to weight ``v``.

    Args:
        q: Query tensor; its last dimension is used as the head dimension.
        k: Key tensor; its last two dimensions are swapped before the matmul
            with ``q``, so they must be matmul-compatible with ``q``.
        v: Value tensor; must be matmul-compatible with the attention weights.
        attn_mask: Optional additive mask. When not ``None`` it is added to
            the scaled scores before the softmax, so it must broadcast
            against them.
        dropout_p: Dropout probability applied to the attention weights.

    Returns:
        The result of ``matmul(attention_weights, v)``.

    Note:
        Dropout is always invoked with ``train=True``, so any ``dropout_p``
        greater than zero makes the output stochastic.
    """
    head_dim = q.size(-1)
    # Swap the last two axes instead of hard-coding axes 2 and 3, so inputs
    # with any number of leading batch dimensions are accepted (the 4-D case
    # is unchanged).
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
    if attn_mask is not None:
        scores = scores + attn_mask
    # Run the softmax in float32, then cast the weights back to q's dtype.
    weights = torch.softmax(scores, dim=-1, dtype=torch.float).to(q.dtype)
    weights = torch.dropout(weights, p=dropout_p, train=True)
    return torch.matmul(weights, v)
| 31 | |
| 32 | |
| 33 | def gen_padded_kwargs(dtype: torch.dtype): |
no test coverage detected
searching dependent graphs…