import numpy as np


def attention_python(
    q: np.ndarray,
    k: np.ndarray,
    v: np.ndarray,
    bias: np.ndarray | None,
    qk_scale: float,
    causal: str,
    window_size: int | None = None,
    layout: str = "BSNH",
):  # pylint: disable=too-many-arguments, too-many-locals, invalid-name
    """Attention operator in Python.

    Parameters
    ----------
    q : np.ndarray
        Query tensor with shape [batch, seq_length, num_heads, head_dim] in the layout specified by
        `layout`.
    k : np.ndarray
        Key tensor with shape [batch, seq_length_kv, num_kv_heads, head_dim] in the layout specified
        by `layout`.
    v : np.ndarray
        Value tensor with shape [batch, seq_length_kv, num_kv_heads, head_dim_v] in the layout
        specified by `layout`.
    bias : Optional[np.ndarray]
        Bias tensor with shape [batch, num_heads, seq_length, seq_length], or None for no bias.
    qk_scale : float
        Scale factor for the query-key product.
    causal : str
        The type of causal mask to apply. Can be "none", "TopLeft", or "BottomRight".
    window_size : Optional[int]
        The window size for the causal mask, or None for no window.
    layout : str
        The layout of the input tensors. One of "BSNH", "BNSH", or "SBNH".

    Returns
    -------
    np.ndarray
        The output tensor with shape [batch, seq_length, num_heads, head_dim_v] in the layout
        specified by `layout`.
    """
    assert layout in ["BSNH", "BNSH", "SBNH"]

    # Position of each named dimension within the layout string.
    dim_b = layout.find("B")
    dim_s = layout.find("S")
    dim_n = layout.find("N")
    dim_h = layout.find("H")

    # Normalise every input to [batch, heads, seq, head_dim] order.
    q = q.transpose(dim_b, dim_n, dim_s, dim_h)  # b, n, s, h
    k = k.transpose(dim_b, dim_n, dim_s, dim_h)  # b, n, s_kv, h
    kt = k.transpose(0, 1, 3, 2)  # b, n, h, s_kv
    v = v.transpose(dim_b, dim_n, dim_s, dim_h)  # b, n, s_kv, h_v

    num_heads = q.shape[1]
    num_kv_heads = k.shape[1]
    s = q.shape[2]
    s_kv = k.shape[2]

    if num_heads != num_kv_heads:
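The layout handling above relies on `str.find` returning the axis that currently holds each named dimension, so passing those indices to `transpose` in B, N, S, H order always yields a [batch, heads, seq, head_dim] tensor. The following is a minimal sketch of that step in isolation. The helper name `to_bnsh` and the test shapes are hypothetical and not part of the original file.

```python
import numpy as np


def to_bnsh(x: np.ndarray, layout: str) -> np.ndarray:
    """Hypothetical helper: permute `x` from `layout` into B, N, S, H order."""
    assert layout in ("BSNH", "BNSH", "SBNH")
    # layout.find(c) is the axis of `x` that holds dimension `c`.
    axes = tuple(layout.find(c) for c in "BNSH")
    return x.transpose(axes)


# Illustrative sizes only (assumed, not taken from the source).
batch, heads, seq, dim = 2, 3, 5, 4
ref = np.arange(batch * heads * seq * dim).reshape(batch, heads, seq, dim)

for layout in ("BSNH", "BNSH", "SBNH"):
    # Store the reference tensor in the given layout ...
    stored = ref.transpose(tuple("BNSH".find(c) for c in layout))
    # ... and check that normalising it recovers the B, N, S, H tensor.
    assert np.array_equal(to_bnsh(stored, layout), ref)
```

Because `transpose` returns a view, this normalisation does not copy the data, which keeps a reference implementation like this cheap to run on every supported layout.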